Skip to content

Commit

Permalink
Issue/108/evaluator mods (#109)
Browse files Browse the repository at this point in the history
* protect dist_to_point against hdf5_groupname being none

* fix variable names and allow using allow_missing=True in run()

* add use of exclude_metrics and allow for multiple point estimates

* fix up unit test to account for setting limits

* protect metrics that don't take limits

* fix location of key definition to avoid using undefined variable
  • Loading branch information
eacharles authored May 20, 2024
1 parent eb4f55a commit 11800c7
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
6 changes: 4 additions & 2 deletions src/rail/evaluation/dist_to_point_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def _process_chunk(self, data_tuple, first):

def _process_all(self, data_tuple):
estimate_data = data_tuple[0]
reference_data = data_tuple[1][self.config.hdf5_groupname][self.config.reference_dictionary_key]

if self.config.hdf5_groupname:
reference_data = data_tuple[1][self.config.hdf5_groupname][self.config.reference_dictionary_key]
else:
reference_data = data_tuple[1][self.config.reference_dictionary_key]
self._process_all_metrics(estimate_data, reference_data)
34 changes: 23 additions & 11 deletions src/rail/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,19 +158,23 @@ def finalize(self):
for metric, cached_metric in self._cached_metrics.items():
if cached_metric.metric_output_type != MetricOutputType.single_value and cached_metric.metric_output_type != MetricOutputType.single_distribution:
continue

if metric not in self._cached_data:
print(f"Skipping {metric} which did not cache data")
matching_keys = []
for key_ in self._cached_data.keys():
if key_.find(metric) == 0:
matching_keys.append(key_)
if not matching_keys:
print(f"Skipping {metric} which did not cache data {list(self._cached_data.keys())}")
continue
if self.comm: # pragma: no cover
self._cached_data[metric] = self.comm.gather(self._cached_data[metric])
for key_ in matching_keys:
if self.comm: # pragma: no cover
self._cached_data[key_] = self.comm.gather(self._cached_data[key_])

if cached_metric.metric_output_type == MetricOutputType.single_value:
summary_data[metric] = np.array([cached_metric.finalize(self._cached_data[metric])])
if cached_metric.metric_output_type == MetricOutputType.single_value:
summary_data[key_] = np.array([cached_metric.finalize(self._cached_data[key_])])

elif cached_metric.metric_output_type == MetricOutputType.single_distribution:
# we expect `cached_metric.finalize` to return a qp.Ensemble
single_distribution_summary_data[metric] = cached_metric.finalize(self._cached_data[metric])
elif cached_metric.metric_output_type == MetricOutputType.single_distribution:
# we expect `cached_metric.finalize` to return a qp.Ensemble
single_distribution_summary_data[key_] = cached_metric.finalize(self._cached_data[key_])

self._summary_handle = self.add_handle('summary', data=summary_data)
self._single_distribution_summary_handle = self.add_handle('single_distribution_summary', data=single_distribution_summary_data)
Expand Down Expand Up @@ -329,6 +333,8 @@ def _build_config_dict(self):

if "all" in self.config.metrics: # pragma: no cover
metric_list = list(self._metric_dict.keys())
for exclude_ in self.config.exclude_metrics:
metric_list.remove(exclude_)
else:
metric_list = self.config.metrics

Expand All @@ -344,9 +350,15 @@ def _build_config_dict(self):

sub_dict = self.config.metric_config.get("general", {}).copy()
sub_dict.update(self.config.metric_config.get(metric_name_, {}))
if 'limits' in self.config:
sub_dict.update(dict(limits=self.config.limits))
self._metric_config_dict[metric_name_] = sub_dict
this_metric_class = self._metric_dict[metric_name_]
this_metric = this_metric_class(**sub_dict)
try:
this_metric = this_metric_class(**sub_dict)
except (TypeError, KeyError):
sub_dict.pop('limits')
this_metric = this_metric_class(**sub_dict)
self._cached_metrics[metric_name_] = this_metric


Expand Down
19 changes: 10 additions & 9 deletions src/rail/evaluation/single_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np

from ceci.config import StageParameter as Param
import qp.metrics
from qp.metrics import MetricInputType, MetricOutputType
from qp.metrics.base_metric_classes import BaseMetric

Expand Down Expand Up @@ -51,8 +52,8 @@ def run(self): # pylint: disable=too-many-branches
Get the truth data from the data store under this stages 'truth' tag
Puts the data into the data store under this stages 'output' tag
"""
input_data_handle = self.get_handle("input")
truth_data_handle = self.get_handle("truth")
input_data_handle = self.get_handle("input", allow_missing=True)
truth_data_handle = self.get_handle("truth", allow_missing=True)

self._input_data_type = input_data_handle.check_pdf_or_point()
self._truth_data_type = truth_data_handle.check_pdf_or_point()
Expand Down Expand Up @@ -114,9 +115,9 @@ def _process_chunk(self, data_tuple, first):
)
continue
for point_estimate_ in self.config.point_estimates:
key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
point_data = np.squeeze(input_data.ancil[point_estimate_])
for truth_point_estimate_ in self.config.truth_point_estimates:
for truth_point_estimate_ in self.config.truth_point_estimates:
key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
self._process_chunk_point_to_point(
this_metric,
key_val,
Expand Down Expand Up @@ -187,9 +188,9 @@ def _process_all(self, data_tuple):
): # pragma: no cover
continue
for point_estimate_ in self.config.point_estimates:
key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
point_data = input_data.ancil[point_estimate_]
for truth_point_estimate_ in self.config.truth_point_estimates:
key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
self._process_all_point_to_point(
this_metric,
key_val,
Expand Down Expand Up @@ -276,7 +277,7 @@ def _process_chunk_dist_to_dist(self, this_metric, key, input_data, truth_data):
print(f"{metric} with output type MetricOutputType.single_value does not support parallel processing yet")
return

accumulated_data = this_metric.accumulate(estimate_data, reference_data)
accumulated_data = this_metric.accumulate(input_data, truth_data)
if self.comm:
self._cached_data[key] = accumulated_data
else:
Expand Down Expand Up @@ -313,7 +314,7 @@ def _process_chunk_dist_to_point(self, this_metric, key, input_data, truth_data)
"single_value does not support parallel processing yet"
)
return
accumulated_data = this_metric.accumulate(estimate_data, reference_data)
accumulated_data = this_metric.accumulate(input_data, truth_data)
if self.comm:
self._cached_data[key] = accumulated_data
else:
Expand Down Expand Up @@ -350,7 +351,7 @@ def _process_chunk_point_to_dist(self, this_metric, key, input_data, truth_data)
"single_value does not support parallel processing yet"
)
return
accumulated_data = this_metric.accumulate(estimate_data, reference_data)
accumulated_data = this_metric.accumulate(input_data, truth_data)
if self.comm:
self._cached_data[key] = accumulated_data
else:
Expand Down Expand Up @@ -386,7 +387,7 @@ def _process_chunk_point_to_point(self, this_metric, key, input_data, truth_data
"single_value does not support parallel processing yet"
)
return
accumulated_data = this_metric.accumulate(estimate_data, reference_data)
accumulated_data = this_metric.accumulate(input_data, truth_data)
if self.comm:
self._cached_data[key] = accumulated_data
else:
Expand Down
3 changes: 2 additions & 1 deletion tests/evaluation/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def test_dist_to_point_evaluator():
_random_state=None,
metric_config={
'brier': {'limits':(0,3.1)},
}
},
limits=[0., 3.1],
)

ensemble = DS.read_file(key='pdfs_data', handle_class=QPHandle, path=pdfs_file)
Expand Down

0 comments on commit 11800c7

Please sign in to comment.