From 7f6725f63dca7448b0e2e5e3bd3d9ad6c2062983 Mon Sep 17 00:00:00 2001 From: Eric Charles Date: Wed, 3 Jul 2024 08:48:14 -0700 Subject: [PATCH] Added option to set tomography_bins to 'none' in naive_stack and point_est_hist, and tests for same --- src/rail/estimation/algos/naive_stack.py | 19 ++++++++++++++++--- src/rail/estimation/algos/point_est_hist.py | 19 ++++++++++++++++--- tests/estimation/test_summarizers.py | 3 +++ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/src/rail/estimation/algos/naive_stack.py b/src/rail/estimation/algos/naive_stack.py index 89128d4d..7729cf42 100644 --- a/src/rail/estimation/algos/naive_stack.py +++ b/src/rail/estimation/algos/naive_stack.py @@ -101,7 +101,15 @@ class NaiveStackMaskedSummarizer(NaiveStackSummarizer): def _setup_iterator(self): - itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')] + + selected_bin = self.config.selected_bin + if self.config.tomography_bins == 'none': + selected_bin = -1 + + if selected_bin == -1: + itrs = [self.input_iterator('input')] + else: + itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')] for it in zip(*itrs): first = True @@ -117,9 +125,11 @@ def _setup_iterator(self): mask = np.ones(pz_data.npdf, dtype=bool) else: mask = d['class_id'] == self.config.selected_bin + if mask is None: + mask = np.ones(pz_data.npdf, dtype=bool) yield start, end, pz_data, mask - def summarize(self, input_data, tomo_bins): + def summarize(self, input_data, tomo_bins=None): """Override the Summarizer.summarize() method to take tomo bins as an additional input @@ -137,7 +147,10 @@ def summarize(self, input_data, tomo_bins): Ensemble with n(z), and any ancilary data """ self.set_data("input", input_data) - self.set_data("tomography_bins", tomo_bins) + if tomo_bins is None: + self.config.tomography_bins = None + else: + self.set_data("tomography_bins", tomo_bins) self.run() self.finalize() return self.get_handle("output") diff --git a/src/rail/estimation/algos/point_est_hist.py b/src/rail/estimation/algos/point_est_hist.py index 852a9f09..cc94e9bb 100644 --- a/src/rail/estimation/algos/point_est_hist.py +++ b/src/rail/estimation/algos/point_est_hist.py @@ -106,7 +106,15 @@ class PointEstHistMaskedSummarizer(PointEstHistSummarizer): outputs = [("output", QPHandle), ("single_NZ", QPHandle)] def _setup_iterator(self): - itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')] + + selected_bin = self.config.selected_bin + if self.config.tomography_bins == 'none': + selected_bin = -1 + + if selected_bin == -1: + itrs = [self.input_iterator('input')] + else: + itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')] for it in zip(*itrs): first = True @@ -122,9 +130,11 @@ def _setup_iterator(self): mask = np.ones(pz_data.npdf, dtype=bool) else: mask = d['class_id'] == self.config.selected_bin + if mask is None: + mask = np.ones(pz_data.npdf, dtype=bool) yield start, end, pz_data, mask - def summarize(self, input_data, tomo_bins): + def summarize(self, input_data, tomo_bins=None): """Override the Summarizer.summarize() method to take tomo bins as an additional input @@ -142,7 +152,10 @@ def summarize(self, input_data, tomo_bins): Ensemble with n(z), and any ancilary data """ self.set_data("input", input_data) - self.set_data("tomography_bins", tomo_bins) + if tomo_bins is None: + self.config.tomography_bins = None + else: + self.set_data("tomography_bins", tomo_bins) self.run() self.finalize() return self.get_handle("output") diff --git a/tests/estimation/test_summarizers.py b/tests/estimation/test_summarizers.py index f2609789..ebbe89d6 100644 --- a/tests/estimation/test_summarizers.py +++ b/tests/estimation/test_summarizers.py @@ -47,6 +47,7 @@ def one_mask_algo(key, summarizer_class, summary_kwargs): os.remove( summarizer.get_output(summarizer.get_aliased_tag("single_NZ"), final_name=True) ) + return summary_ens @@ -87,6 +88,7 @@ def test_naive_stack_masked(): summary_config_dict = {} summarizer_class = naive_stack.NaiveStackMaskedSummarizer _ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict) + _ = one_algo("NaiveStack", summarizer_class, summary_config_dict) def test_point_estimate_hist_masekd(): @@ -96,3 +98,4 @@ def test_point_estimate_hist_masekd(): summary_config_dict = {} summarizer_class = point_est_hist.PointEstHistMaskedSummarizer _ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict) + _ = one_algo("PointEstimateHist", summarizer_class, summary_config_dict)