Skip to content

Commit

Permalink
added tests for masked summarizers
Browse files Browse the repository at this point in the history
  • Loading branch information
eacharles committed Jul 22, 2024
1 parent 5d5786a commit a93b41d
Showing 1 changed file with 40 additions and 1 deletion.
41 changes: 40 additions & 1 deletion tests/estimation/test_summarizers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import os

from rail.core.data import QPHandle
from rail.core.data import QPHandle, TableHandle
from rail.core.stage import RailStage
from rail.utils.path_utils import RAILDIR
from rail.estimation.algos import naive_stack, point_est_hist, var_inf

testdata = os.path.join(RAILDIR, "rail/examples_data/testdata/output_BPZ_lite.hdf5")
tomobins = os.path.join(RAILDIR, "rail/examples_data/testdata/output_tomo.hdf5")
DS = RailStage.data_store


Expand All @@ -15,6 +16,7 @@ def one_algo(key, summarizer_class, summary_kwargs):
Run summarize
"""
DS.__class__.allow_overwrite = True
DS.clear()
test_data = DS.read_file("test_data", QPHandle, testdata)
summarizer = summarizer_class.make_stage(name=key, **summary_kwargs)
summary_ens = summarizer.summarize(test_data)
Expand All @@ -27,6 +29,27 @@ def one_algo(key, summarizer_class, summary_kwargs):
return summary_ens


def one_mask_algo(key, summarizer_class, summary_kwargs):
"""
A basic test of running an summaizer subclass
Run summarize
"""
DS.__class__.allow_overwrite = True
DS.clear()
test_data = DS.read_file("test_data", QPHandle, testdata)
tomo_bins = DS.read_file("tomo_bins", TableHandle, tomobins)

summarizer = summarizer_class.make_stage(name=key, **summary_kwargs)
summary_ens = summarizer.summarize(test_data, tomo_bins)
os.remove(
summarizer.get_output(summarizer.get_aliased_tag("output"), final_name=True)
)
os.remove(
summarizer.get_output(summarizer.get_aliased_tag("single_NZ"), final_name=True)
)
return summary_ens


def test_naive_stack():
"""Basic end to end test for the Naive stack informer to estimator stages"""
naive_stack_informer_stage = naive_stack.NaiveStackInformer.make_stage()
Expand Down Expand Up @@ -57,3 +80,19 @@ def test_var_inference_stack():
summary_config_dict = {}
summarizer_class = var_inf.VarInfStackSummarizer
_ = one_algo("VariationalInference", summarizer_class, summary_config_dict)


def test_naive_stack_masked():
"""Basic end to end test for the Naive stack informer to estimator stages"""
summary_config_dict = {}
summarizer_class = naive_stack.NaiveStackMaskedSummarizer
_ = one_mask_algo("NaiveStack", summarizer_class, summary_config_dict)


def test_point_estimate_hist_masekd():
"""Basic end to end test for the point estimate histogram informer to estimator
stages
"""
summary_config_dict = {}
summarizer_class = point_est_hist.PointEstHistMaskedSummarizer
_ = one_mask_algo("PointEstimateHist", summarizer_class, summary_config_dict)

0 comments on commit a93b41d

Please sign in to comment.