From 06847f185e615ba443fbc0e126a0ec41ee8c02c0 Mon Sep 17 00:00:00 2001 From: Drew Oldag Date: Wed, 17 Jan 2024 14:47:54 -0800 Subject: [PATCH] Updated logic in `calculate_point_estimates` to conditionally create or append to ancil dictionary. --- src/rail/core/point_estimation.py | 5 +++- tests/core/test_point_estimation.py | 38 +++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/src/rail/core/point_estimation.py b/src/rail/core/point_estimation.py index b9a98d1d..a47da030 100644 --- a/src/rail/core/point_estimation.py +++ b/src/rail/core/point_estimation.py @@ -52,7 +52,10 @@ def calculate_point_estimates(self, qp_dist, grid=None): ancil_dict.update(median = median_value) if calculated_point_estimates: - qp_dist.set_ancil(ancil_dict) + if qp_dist.ancil is None: + qp_dist.set_ancil(ancil_dict) + else: + qp_dist.add_to_ancil(ancil_dict) return qp_dist diff --git a/tests/core/test_point_estimation.py b/tests/core/test_point_estimation.py index b09e63de..1d95f78d 100644 --- a/tests/core/test_point_estimation.py +++ b/tests/core/test_point_estimation.py @@ -1,3 +1,4 @@ +# pylint: disable=no-member import pytest import numpy as np @@ -81,3 +82,40 @@ def test_mode_no_point_estimates(): output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None) assert output_ensemble.ancil is None + +def test_keep_existing_ancil_data(): + """Make sure that we don't overwrite the ancil data if it already exists. + """ + config_dict = {'zmin':0.0, 'zmax': 3.0, 'nzbins':100, 'calculated_point_estimates': ['mode']} + + test_estimator = CatEstimator.make_stage(name='test', **config_dict) + + locs = 2* (np.random.uniform(size=(100,1))-0.5) + scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5) + test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) + + test_ensemble.set_ancil({'foo': np.zeros(100)}) + + output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None) + + assert 'foo' in output_ensemble.ancil + assert np.all(output_ensemble.ancil['foo'] == 0.0) + assert len(output_ensemble.ancil['foo']) == 100 + +def test_write_new_ancil_data(): + """Make sure that we don't overwrite the ancil data if it already exists. + """ + config_dict = {'zmin':0.0, 'zmax': 3.0, 'nzbins':100, 'calculated_point_estimates': ['mode']} + + test_estimator = CatEstimator.make_stage(name='test', **config_dict) + + locs = 2* (np.random.uniform(size=(100,1))-0.5) + scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5) + test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) + + test_ensemble.set_ancil({'foo': np.zeros(100)}) + + output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None) + + assert 'mode' in output_ensemble.ancil + assert len(output_ensemble.ancil['mode']) == 100