Skip to content

Commit

Permalink
Updated logic in calculate_point_estimates to conditionally create …
Browse files Browse the repository at this point in the history
…or append to ancil dictionary.
  • Loading branch information
drewoldag committed Jan 17, 2024
1 parent 88953e8 commit 06847f1
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/rail/core/point_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def calculate_point_estimates(self, qp_dist, grid=None):
ancil_dict.update(median = median_value)

if calculated_point_estimates:
qp_dist.set_ancil(ancil_dict)
if qp_dist.ancil is None:
qp_dist.set_ancil(ancil_dict)
else:
qp_dist.add_to_ancil(ancil_dict)

return qp_dist

Expand Down
38 changes: 38 additions & 0 deletions tests/core/test_point_estimation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=no-member
import pytest
import numpy as np

Expand Down Expand Up @@ -81,3 +82,40 @@ def test_mode_no_point_estimates():
output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None)

assert output_ensemble.ancil is None

def test_keep_existing_ancil_data():
"""Make sure that we don't overwrite the ancil data if it already exists.
"""
config_dict = {'zmin':0.0, 'zmax': 3.0, 'nzbins':100, 'calculated_point_estimates': ['mode']}

test_estimator = CatEstimator.make_stage(name='test', **config_dict)

locs = 2* (np.random.uniform(size=(100,1))-0.5)
scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))

test_ensemble.set_ancil({'foo': np.zeros(100)})

output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None)

assert 'foo' in output_ensemble.ancil
assert np.all(output_ensemble.ancil['foo'] == 0.0)
assert len(output_ensemble.ancil['foo']) == 100

def test_write_new_ancil_data():
"""Make sure that we don't overwrite the ancil data if it already exists.
"""
config_dict = {'zmin':0.0, 'zmax': 3.0, 'nzbins':100, 'calculated_point_estimates': ['mode']}

test_estimator = CatEstimator.make_stage(name='test', **config_dict)

locs = 2* (np.random.uniform(size=(100,1))-0.5)
scales = 1 + 0.2*(np.random.uniform(size=(100,1))-0.5)
test_ensemble = qp.Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))

test_ensemble.set_ancil({'foo': np.zeros(100)})

output_ensemble = test_estimator.calculate_point_estimates(test_ensemble, None)

assert 'mode' in output_ensemble.ancil
assert len(output_ensemble.ancil['mode']) == 100

0 comments on commit 06847f1

Please sign in to comment.