Skip to content

Commit

Permalink
Added option to set tomography_bins to 'none' in naive_stack and poin…
Browse files Browse the repository at this point in the history
…t_est_hist, and tests for same
  • Loading branch information
eacharles committed Jul 22, 2024
1 parent a93b41d commit 7f6725f
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 6 deletions.
19 changes: 16 additions & 3 deletions src/rail/estimation/algos/naive_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,15 @@ class NaiveStackMaskedSummarizer(NaiveStackSummarizer):


def _setup_iterator(self):
itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')]

selected_bin = self.config.selected_bin
if self.config.tomography_bins == 'none':
selected_bin = -1

if selected_bin == -1:
itrs = [self.input_iterator('input')]
else:
itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')]

for it in zip(*itrs):
first = True
Expand All @@ -117,9 +125,11 @@ def _setup_iterator(self):
mask = np.ones(pz_data.npdf, dtype=bool)
else:
mask = d['class_id'] == self.config.selected_bin
if mask is None:
mask = np.ones(pz_data.npdf, dtype=bool)
yield start, end, pz_data, mask

def summarize(self, input_data, tomo_bins):
def summarize(self, input_data, tomo_bins=None):
"""Override the Summarizer.summarize() method to take tomo bins
as an additional input
Expand All @@ -137,7 +147,10 @@ def summarize(self, input_data, tomo_bins):
Ensemble with n(z), and any ancilary data
"""
self.set_data("input", input_data)
self.set_data("tomography_bins", tomo_bins)
if tomo_bins is None:
self.config.tomography_bins = None
else:
self.set_data("tomography_bins", tomo_bins)
self.run()
self.finalize()
return self.get_handle("output")
Expand Down
19 changes: 16 additions & 3 deletions src/rail/estimation/algos/point_est_hist.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,15 @@ class PointEstHistMaskedSummarizer(PointEstHistSummarizer):
outputs = [("output", QPHandle), ("single_NZ", QPHandle)]

def _setup_iterator(self):
itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')]

selected_bin = self.config.selected_bin
if self.config.tomography_bins == 'none':
selected_bin = -1

if selected_bin == -1:
itrs = [self.input_iterator('input')]
else:
itrs = [self.input_iterator('input'), self.input_iterator('tomography_bins')]

for it in zip(*itrs):
first = True
Expand All @@ -122,9 +130,11 @@ def _setup_iterator(self):
mask = np.ones(pz_data.npdf, dtype=bool)
else:
mask = d['class_id'] == self.config.selected_bin
if mask is None:
mask = np.ones(pz_data.npdf, dtype=bool)
yield start, end, pz_data, mask

def summarize(self, input_data, tomo_bins):
def summarize(self, input_data, tomo_bins=None):
"""Override the Summarizer.summarize() method to take tomo bins
as an additional input
Expand All @@ -142,7 +152,10 @@ def summarize(self, input_data, tomo_bins):
Ensemble with n(z), and any ancilary data
"""
self.set_data("input", input_data)
self.set_data("tomography_bins", tomo_bins)
if tomo_bins is None:
self.config.tomography_bins = None
else:
self.set_data("tomography_bins", tomo_bins)
self.run()
self.finalize()
return self.get_handle("output")
Expand Down
3 changes: 3 additions & 0 deletions tests/estimation/test_summarizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def one_mask_algo(key, summarizer_class, summary_kwargs):
os.remove(
summarizer.get_output(summarizer.get_aliased_tag("single_NZ"), final_name=True)
)

return summary_ens


Expand Down Expand Up @@ -87,6 +88,7 @@ def test_naive_stack_masked():
summary_config_dict = {}
summarizer_class = naive_stack.NaiveStackMaskedSummarizer
_ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict)
_ = one_algo("NaiveStack", summarizer_class, summary_config_dict)


def test_point_estimate_hist_masekd():
Expand All @@ -96,3 +98,4 @@ def test_point_estimate_hist_masekd():
summary_config_dict = {}
summarizer_class = point_est_hist.PointEstHistMaskedSummarizer
_ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict)
_ = one_algo("PointEstimateHist", summarizer_class, summary_config_dict)

0 comments on commit 7f6725f

Please sign in to comment.