From bb064dddf124bb7a3e68b9e661c72e24a0ac8063 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Mon, 29 Jul 2024 09:33:02 -0700 Subject: [PATCH] Issue/124/mask summarizers (#130) * WIP, first go at doing masked summarizers * WIP, delinting * Fixes to masked summarizers * Added summarize() method overrides * added tests for masked summarizers * Added option to set tomography_bins to 'none' in naive_stack and point_est_hist, and tests for same * Fix up coverage --- src/rail/estimation/algos/equal_count.py | 6 +- src/rail/estimation/algos/naive_stack.py | 93 +++++++++++++++++++-- src/rail/estimation/algos/point_est_hist.py | 91 ++++++++++++++++++-- src/rail/estimation/summarizer.py | 2 +- tests/estimation/test_summarizers.py | 57 ++++++++++++- 5 files changed, 226 insertions(+), 23 deletions(-) diff --git a/src/rail/estimation/algos/equal_count.py b/src/rail/estimation/algos/equal_count.py index f232a45b..5e56a730 100644 --- a/src/rail/estimation/algos/equal_count.py +++ b/src/rail/estimation/algos/equal_count.py @@ -6,7 +6,7 @@ import numpy as np from ceci.config import StageParameter as Param from rail.estimation.classifier import PZClassifier -from rail.core.data import TableHandle +from rail.core.data import Hdf5Handle class EqualCountClassifier(PZClassifier): @@ -27,14 +27,14 @@ class EqualCountClassifier(PZClassifier): nbins=Param(int, 5, msg="Number of tomographic bins"), no_assign=Param(int, -99, msg="Value for no assignment flag"), ) - outputs = [("output", TableHandle)] + outputs = [("output", Hdf5Handle)] def run(self): test_data = self.get_data("input") npdf = test_data.npdf try: - zb = test_data.ancil[self.config.point_estimate] + zb = np.squeeze(test_data.ancil[self.config.point_estimate]) except KeyError as msg: raise KeyError( f"{self.config.point_estimate} is not contained in the data ancil, " diff --git a/src/rail/estimation/algos/naive_stack.py b/src/rail/estimation/algos/naive_stack.py index 0fa2470b..8af43def 100644 --- a/src/rail/estimation/algos/naive_stack.py +++ b/src/rail/estimation/algos/naive_stack.py @@ -8,7 +8,7 @@ from ceci.config import StageParameter as Param from rail.estimation.summarizer import PZSummarizer from rail.estimation.informer import PzInformer -from rail.core.data import QPHandle +from rail.core.data import QPHandle, TableHandle class NaiveStackInformer(PzInformer): @@ -40,8 +40,16 @@ def __init__(self, args, **kwargs): super().__init__(args, **kwargs) self.zgrid = None + def _setup_iterator(self): + itr = self.input_iterator("input") + for s, e, d in itr: + yield s, e, d, np.ones(e-s, dtype=bool) + + def run(self): - iterator = self.input_iterator("input") + handle = self.get_handle("input", allow_missing=True) + self._input_length = handle.size() + iterator = self._setup_iterator() self.zgrid = np.linspace( self.config.zmin, self.config.zmax, self.config.nzbins + 1 ) @@ -51,9 +59,9 @@ def run(self): bootstrap_matrix = self._broadcast_bootstrap_matrix() first = True - for s, e, test_data in iterator: + for s, e, test_data, mask in iterator: print(f"Process {self.rank} running estimator on chunk {s} - {e}") - self._process_chunk(s, e, test_data, first, bootstrap_matrix, yvals, bvals) + self._process_chunk(s, e, test_data, mask, first, bootstrap_matrix, yvals, bvals) first = False if self.comm is not None: # pragma: no cover bvals, yvals = self._join_histograms(bvals, yvals) @@ -66,15 +74,82 @@ def run(self): self.add_data("output", sample_ens) self.add_data("single_NZ", qp_d) - def _process_chunk(self, start, end, data, _first, bootstrap_matrix, yvals, bvals): + def _process_chunk(self, start, end, data, mask, _first, bootstrap_matrix, yvals, bvals): pdf_vals = data.pdf(self.zgrid) + squeeze_mask = np.squeeze(mask) yvals += np.expand_dims( - np.sum(np.where(np.isfinite(pdf_vals), pdf_vals, 0.0), axis=0), 0 + np.sum(np.where(np.isfinite(pdf_vals[squeeze_mask,:]), pdf_vals[squeeze_mask], 0.0), axis=0), 0 ) # qp_d is the normalized probability of the stack, we need to know how many galaxies were for i in range(self.config.nsamples): bootstrap_draws = bootstrap_matrix[:, i] # Neither all of the bootstrap_draws are in this chunk nor the index starts at "start" - mask = (bootstrap_draws >= start) & (bootstrap_draws < end) - bootstrap_draws = bootstrap_draws[mask] - start - bvals[i] += np.sum(pdf_vals[bootstrap_draws], axis=0) + chunk_mask = (bootstrap_draws >= start) & (bootstrap_draws < end) + bootstrap_draws = bootstrap_draws[chunk_mask] - start + zarr = np.where(np.expand_dims(mask, -1), pdf_vals, 0.)[bootstrap_draws] + bvals[i] += np.sum(zarr, axis=0) + + +class NaiveStackMaskedSummarizer(NaiveStackSummarizer): + name = "NaiveStackMaskedSummarizer" + config_options = NaiveStackSummarizer.config_options.copy() + config_options.update( + selected_bin=Param(int, -1, msg="bin to use"), + ) + inputs = [("input", QPHandle), ("tomography_bins", TableHandle)] + outputs = [("output", QPHandle), ("single_NZ", QPHandle)] + + + def _setup_iterator(self): + + selected_bin = self.config.selected_bin + if self.config.tomography_bins in ['none', None]: + selected_bin = -1 + + if selected_bin == -1: + itrs = [self.input_iterator('input')] + else: + itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')] + + for it in zip(*itrs): + first = True + mask = None + for s, e, d in it: + if first: + start = s + end = e + pz_data = d + first = False + else: + mask = d['class_id'] == self.config.selected_bin + if mask is None: + mask = np.ones(pz_data.npdf, dtype=bool) + yield start, end, pz_data, mask + + def summarize(self, input_data, tomo_bins=None): + """Override the Summarizer.summarize() method to take tomo bins + as an additional input + + Parameters + ---------- + input_data : `qp.Ensemble` + Per-galaxy p(z), and any ancilary data associated with it + + tomo_bins : `table-like` + Tomographic bins file + + Returns + ------- + output: `qp.Ensemble` + Ensemble with n(z), and any ancilary data + """ + self.set_data("input", input_data) + if tomo_bins is None: + self.config.tomography_bins = None + else: + self.set_data("tomography_bins", tomo_bins) + self.run() + self.finalize() + return self.get_handle("output") + + diff --git a/src/rail/estimation/algos/point_est_hist.py b/src/rail/estimation/algos/point_est_hist.py index ade6dde0..629648d4 100644 --- a/src/rail/estimation/algos/point_est_hist.py +++ b/src/rail/estimation/algos/point_est_hist.py @@ -8,7 +8,7 @@ from ceci.config import StageParameter as Param from rail.estimation.summarizer import PZSummarizer from rail.estimation.informer import PzInformer -from rail.core.data import QPHandle +from rail.core.data import QPHandle, TableHandle class PointEstHistInformer(PzInformer): @@ -42,8 +42,15 @@ def __init__(self, args, **kwargs): self.zgrid = None self.bincents = None + def _setup_iterator(self): + itr = self.input_iterator("input") + for s, e, d in itr: + yield s, e, d, np.ones(e-s, dtype=bool) + def run(self): - iterator = self.input_iterator("input") + handle = self.get_handle("input", allow_missing=True) + self._input_length = handle.size() + iterator = self._setup_iterator() self.zgrid = np.linspace( self.config.zmin, self.config.zmax, self.config.nzbins + 1 ) @@ -54,10 +61,10 @@ def run(self): hist_vals = np.zeros((self.config.nsamples, self.config.nzbins)) first = True - for s, e, test_data in iterator: + for s, e, test_data, mask in iterator: print(f"Process {self.rank} running estimator on chunk {s} - {e}") self._process_chunk( - s, e, test_data, first, bootstrap_matrix, single_hist, hist_vals + s, e, test_data, mask, first, bootstrap_matrix, single_hist, hist_vals ) first = False if self.comm is not None: # pragma: no cover @@ -74,14 +81,80 @@ def run(self): self.add_data("single_NZ", qp_d) def _process_chunk( - self, start, end, test_data, _first, bootstrap_matrix, single_hist, hist_vals + self, start, end, test_data, mask, _first, bootstrap_matrix, single_hist, hist_vals ): zb = test_data.ancil[self.config.point_estimate] - single_hist += np.histogram(zb, bins=self.zgrid)[0] + single_hist += np.histogram(zb[mask], bins=self.zgrid)[0] for i in range(self.config.nsamples): bootstrap_indeces = bootstrap_matrix[:, i] # Neither all of the bootstrap_draws are in this chunk nor the index starts at "start" - mask = (bootstrap_indeces >= start) & (bootstrap_indeces < end) - bootstrap_indeces = bootstrap_indeces[mask] - start - zarr = zb[bootstrap_indeces] + chunk_mask = (bootstrap_indeces >= start) & (bootstrap_indeces < end) + bootstrap_indeces = bootstrap_indeces[chunk_mask] - start + zarr = np.where(mask, zb, np.nan)[bootstrap_indeces] hist_vals[i] += np.histogram(zarr, bins=self.zgrid)[0] + + +class PointEstHistMaskedSummarizer(PointEstHistSummarizer): + """Summarizer which simply histograms a point estimate""" + + name = "PointEstHistMaskedSummarizer" + config_options = PointEstHistSummarizer.config_options.copy() + config_options.update( + selected_bin=Param(int, -1, msg="bin to use"), + ) + inputs = [("input", QPHandle), ("tomography_bins", TableHandle)] + outputs = [("output", QPHandle), ("single_NZ", QPHandle)] + + def _setup_iterator(self): + + selected_bin = self.config.selected_bin + if self.config.tomography_bins in ['none', None]: + selected_bin = -1 + + if selected_bin == -1: + itrs = [self.input_iterator('input')] + else: + itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')] + + for it in zip(*itrs): + first = True + mask = None + for s, e, d in it: + if first: + start = s + end = e + pz_data = d + first = False + else: + mask = d['class_id'] == self.config.selected_bin + if mask is None: + mask = np.ones(pz_data.npdf, dtype=bool) + yield start, end, pz_data, mask + + def summarize(self, input_data, tomo_bins=None): + """Override the Summarizer.summarize() method to take tomo bins + as an additional input + + Parameters + ---------- + input_data : `qp.Ensemble` + Per-galaxy p(z), and any ancilary data associated with it + + tomo_bins : `table-like` + Tomographic bins file + + Returns + ------- + output: `qp.Ensemble` + Ensemble with n(z), and any ancilary data + """ + self.set_data("input", input_data) + if tomo_bins is None: + self.config.tomography_bins = None + else: + self.set_data("tomography_bins", tomo_bins) + self.run() + self.finalize() + return self.get_handle("output") + + diff --git a/src/rail/estimation/summarizer.py b/src/rail/estimation/summarizer.py index 9d932ba0..4dd65178 100644 --- a/src/rail/estimation/summarizer.py +++ b/src/rail/estimation/summarizer.py @@ -100,7 +100,7 @@ def summarize(self, input_data): self.finalize() return self.get_handle("output") - def _broadcast_bootstrap_matrix(self): + def _broadcast_bootstrap_matrix(self): rng = np.random.default_rng(seed=self.config.seed) # Only one of the nodes needs to produce the bootstrap indices ngal = self._input_length diff --git a/tests/estimation/test_summarizers.py b/tests/estimation/test_summarizers.py index 0814692c..6c65840a 100644 --- a/tests/estimation/test_summarizers.py +++ b/tests/estimation/test_summarizers.py @@ -1,11 +1,12 @@ import os -from rail.core.data import QPHandle +from rail.core.data import QPHandle, TableHandle from rail.core.stage import RailStage from rail.utils.path_utils import RAILDIR from rail.estimation.algos import naive_stack, point_est_hist, var_inf testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.hdf5") +tomobins = os.path.join(RAILDIR, "rail/examples_data/testdata/output_tomo.hdf5") DS = RailStage.data_store @@ -15,6 +16,7 @@ def one_algo(key, summarizer_class, summary_kwargs): Run summarize """ DS.__class__.allow_overwrite = True + DS.clear() test_data = DS.read_file("test_data", QPHandle, testdata) summarizer = summarizer_class.make_stage(name=key, **summary_kwargs) summary_ens = summarizer.summarize(test_data) @@ -27,6 +29,37 @@ def one_algo(key, summarizer_class, summary_kwargs): return summary_ens +def one_mask_algo(key, summarizer_class, summary_kwargs): + """ + A basic test of running an summaizer subclass + Run summarize + """ + DS.__class__.allow_overwrite = True + DS.clear() + test_data = DS.read_file("test_data", QPHandle, testdata) + tomo_bins = DS.read_file("tomo_bins", TableHandle, tomobins) + + summarizer = summarizer_class.make_stage(name=key, selected_bin=1, **summary_kwargs) + summary_ens = summarizer.summarize(test_data, tomo_bins) + os.remove( + summarizer.get_output(summarizer.get_aliased_tag("output"), final_name=True) + ) + os.remove( + summarizer.get_output(summarizer.get_aliased_tag("single_NZ"), final_name=True) + ) + + summarizer_2 = summarizer_class.make_stage(name=f"{key}_2", **summary_kwargs) + summary_2_ens = summarizer_2.summarize(test_data, None) + os.remove( + summarizer_2.get_output(summarizer_2.get_aliased_tag("output"), final_name=True) + ) + os.remove( + summarizer_2.get_output(summarizer_2.get_aliased_tag("single_NZ"), final_name=True) + ) + + return [summary_ens, summary_2_ens] + + def test_naive_stack(): """Basic end to end test for the Naive stack informer to estimator stages""" naive_stack_informer_stage = naive_stack.NaiveStackInformer.make_stage() @@ -57,3 +90,25 @@ def test_var_inference_stack(): summary_config_dict = {} summarizer_class = var_inf.VarInfStackSummarizer _ = one_algo("VariationalInference", summarizer_class, summary_config_dict) + + +def test_naive_stack_masked(): + """Basic end to end test for the Naive stack informer to estimator stages""" + summary_config_dict = dict( + chunk_size=5, + ) + summarizer_class = naive_stack.NaiveStackMaskedSummarizer + _ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict) + _ = one_algo("NaiveStack", summarizer_class, summary_config_dict) + + +def test_point_estimate_hist_masekd(): + """Basic end to end test for the point estimate histogram informer to estimator + stages + """ + summary_config_dict = dict( + chunk_size=5, + ) + summarizer_class = point_est_hist.PointEstHistMaskedSummarizer + _ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict) + _ = one_algo("PointEstimateHist", summarizer_class, summary_config_dict)