From 11800c79a25c5cfc3345cbf8daafbcb62f19d220 Mon Sep 17 00:00:00 2001
From: Eric Charles
Date: Mon, 20 May 2024 08:30:51 -0700
Subject: [PATCH] Issue/108/evaluator mods (#109)

* protect dist_to_point against hdf5_groupname being None
* fix variable names and allow using allow_missing=True in run()
* add use of exclude_metrics and allow for multiple point estimates
* fix up unit test to account for setting limits
* protect metrics that don't take limits
* fix location of key definition to avoid using an undefined variable
---
 .../evaluation/dist_to_point_evaluator.py |  6 ++--
 src/rail/evaluation/evaluator.py          | 34 +++++++++++++------
 src/rail/evaluation/single_evaluator.py   | 19 ++++++-----
 tests/evaluation/test_evaluation.py       |  3 +-
 4 files changed, 39 insertions(+), 23 deletions(-)

diff --git a/src/rail/evaluation/dist_to_point_evaluator.py b/src/rail/evaluation/dist_to_point_evaluator.py
index c72ea8f0..712f77e2 100644
--- a/src/rail/evaluation/dist_to_point_evaluator.py
+++ b/src/rail/evaluation/dist_to_point_evaluator.py
@@ -61,6 +61,8 @@ def _process_chunk(self, data_tuple, first):
 
     def _process_all(self, data_tuple):
         estimate_data = data_tuple[0]
-        reference_data = data_tuple[1][self.config.hdf5_groupname][self.config.reference_dictionary_key]
-
+        if self.config.hdf5_groupname:
+            reference_data = data_tuple[1][self.config.hdf5_groupname][self.config.reference_dictionary_key]
+        else:
+            reference_data = data_tuple[1][self.config.reference_dictionary_key]
         self._process_all_metrics(estimate_data, reference_data)
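
Note: the guard above lets _process_all read the reference catalog whether or not it is nested under an HDF5 group. A minimal sketch of the access pattern, using a plain dict and illustrative values (the "photometry"/"redshift" layout is hypothetical; only hdf5_groupname and reference_dictionary_key come from the diff):

    # Stand-in for data_tuple[1]; the nesting mimics an HDF5 group layout.
    data = {"photometry": {"redshift": [0.1, 0.5, 1.2]}}

    hdf5_groupname = "photometry"          # may also be None or ""
    reference_dictionary_key = "redshift"

    if hdf5_groupname:  # a falsy groupname means the table is flat
        reference_data = data[hdf5_groupname][reference_dictionary_key]
    else:
        reference_data = data[reference_dictionary_key]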
diff --git a/src/rail/evaluation/evaluator.py b/src/rail/evaluation/evaluator.py
index 588011c7..d8a58a6f 100644
--- a/src/rail/evaluation/evaluator.py
+++ b/src/rail/evaluation/evaluator.py
@@ -158,19 +158,23 @@ def finalize(self):
         for metric, cached_metric in self._cached_metrics.items():
             if cached_metric.metric_output_type != MetricOutputType.single_value and cached_metric.metric_output_type != MetricOutputType.single_distribution:
                 continue
-
-            if metric not in self._cached_data:
-                print(f"Skipping {metric} which did not cache data")
+            matching_keys = []
+            for key_ in self._cached_data.keys():
+                if key_.find(metric) == 0:
+                    matching_keys.append(key_)
+            if not matching_keys:
+                print(f"Skipping {metric} which did not cache data {list(self._cached_data.keys())}")
                 continue
 
-            if self.comm:  # pragma: no cover
-                self._cached_data[metric] = self.comm.gather(self._cached_data[metric])
+            for key_ in matching_keys:
+                if self.comm:  # pragma: no cover
+                    self._cached_data[key_] = self.comm.gather(self._cached_data[key_])
 
-            if cached_metric.metric_output_type == MetricOutputType.single_value:
-                summary_data[metric] = np.array([cached_metric.finalize(self._cached_data[metric])])
+                if cached_metric.metric_output_type == MetricOutputType.single_value:
+                    summary_data[key_] = np.array([cached_metric.finalize(self._cached_data[key_])])
 
-            elif cached_metric.metric_output_type == MetricOutputType.single_distribution:
-                # we expect `cached_metric.finalize` to return a qp.Ensemble
-                single_distribution_summary_data[metric] = cached_metric.finalize(self._cached_data[metric])
+                elif cached_metric.metric_output_type == MetricOutputType.single_distribution:
+                    # we expect `cached_metric.finalize` to return a qp.Ensemble
+                    single_distribution_summary_data[key_] = cached_metric.finalize(self._cached_data[key_])
 
         self._summary_handle = self.add_handle('summary', data=summary_data)
         self._single_distribution_summary_handle = self.add_handle('single_distribution_summary', data=single_distribution_summary_data)
@@ -329,6 +333,8 @@ def _build_config_dict(self):
 
         if "all" in self.config.metrics:  # pragma: no cover
             metric_list = list(self._metric_dict.keys())
+            for exclude_ in self.config.exclude_metrics:
+                metric_list.remove(exclude_)
         else:
             metric_list = self.config.metrics
 
@@ -344,9 +350,15 @@ def _build_config_dict(self):
 
             sub_dict = self.config.metric_config.get("general", {}).copy()
             sub_dict.update(self.config.metric_config.get(metric_name_, {}))
+            if 'limits' in self.config:
+                sub_dict.update(dict(limits=self.config.limits))
             self._metric_config_dict[metric_name_] = sub_dict
 
             this_metric_class = self._metric_dict[metric_name_]
-            this_metric = this_metric_class(**sub_dict)
+            try:
+                this_metric = this_metric_class(**sub_dict)
+            except (TypeError, KeyError):
+                sub_dict.pop('limits')
+                this_metric = this_metric_class(**sub_dict)
             self._cached_metrics[metric_name_] = this_metric
 
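
Note: _build_config_dict now forwards a top-level limits setting to every metric and falls back gracefully for metric classes whose constructors do not accept it. A self-contained sketch of that fallback, with stand-in classes (TakesLimits and NoLimits are illustrative, not qp.metrics types):

    # Stand-ins for metric classes; only the constructor signatures matter here.
    class TakesLimits:
        def __init__(self, limits=None):
            self.limits = limits

    class NoLimits:
        def __init__(self):
            pass

    def build_metric(metric_class, **sub_dict):
        # Mirrors the try/except above: retry without 'limits' if the
        # metric class rejects the keyword.
        try:
            return metric_class(**sub_dict)
        except (TypeError, KeyError):
            sub_dict.pop('limits')
            return metric_class(**sub_dict)

    print(build_metric(TakesLimits, limits=(0.0, 3.1)).limits)        # (0.0, 3.1)
    print(type(build_metric(NoLimits, limits=(0.0, 3.1))).__name__)   # NoLimits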
diff --git a/src/rail/evaluation/single_evaluator.py b/src/rail/evaluation/single_evaluator.py
index 20d4127e..b9a5b479 100644
--- a/src/rail/evaluation/single_evaluator.py
+++ b/src/rail/evaluation/single_evaluator.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 from ceci.config import StageParameter as Param
+import qp.metrics
 from qp.metrics import MetricInputType, MetricOutputType
 from qp.metrics.base_metric_classes import BaseMetric
 
@@ -51,8 +52,8 @@ def run(self):  # pylint: disable=too-many-branches
         Get the truth data from the data store under this stages 'truth' tag
         Puts the data into the data store under this stages 'output' tag
         """
-        input_data_handle = self.get_handle("input")
-        truth_data_handle = self.get_handle("truth")
+        input_data_handle = self.get_handle("input", allow_missing=True)
+        truth_data_handle = self.get_handle("truth", allow_missing=True)
 
         self._input_data_type = input_data_handle.check_pdf_or_point()
         self._truth_data_type = truth_data_handle.check_pdf_or_point()
@@ -114,9 +115,9 @@ def _process_chunk(self, data_tuple, first):
                 )
                 continue
             for point_estimate_ in self.config.point_estimates:
-                key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
                 point_data = np.squeeze(input_data.ancil[point_estimate_])
-                for truth_point_estimate_ in self.config.truth_point_estimates:
+                for truth_point_estimate_ in self.config.truth_point_estimates:
+                    key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
                     self._process_chunk_point_to_point(
                         this_metric,
                         key_val,
@@ -187,9 +188,9 @@ def _process_all(self, data_tuple):
             ):  # pragma: no cover
                 continue
             for point_estimate_ in self.config.point_estimates:
-                key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
                 point_data = input_data.ancil[point_estimate_]
                 for truth_point_estimate_ in self.config.truth_point_estimates:
+                    key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
                     self._process_all_point_to_point(
                         this_metric,
                         key_val,
@@ -276,7 +277,7 @@ def _process_chunk_dist_to_dist(self, this_metric, key, input_data, truth_data):
             print(f"{metric} with output type MetricOutputType.single_value does not support parallel processing yet")
             return
 
-        accumulated_data = this_metric.accumulate(estimate_data, reference_data)
+        accumulated_data = this_metric.accumulate(input_data, truth_data)
         if self.comm:
             self._cached_data[key] = accumulated_data
         else:
@@ -313,7 +314,7 @@ def _process_chunk_dist_to_point(self, this_metric, key, input_data, truth_data):
                 "single_value does not support parallel processing yet"
             )
             return
-        accumulated_data = this_metric.accumulate(estimate_data, reference_data)
+        accumulated_data = this_metric.accumulate(input_data, truth_data)
         if self.comm:
             self._cached_data[key] = accumulated_data
         else:
@@ -350,7 +351,7 @@ def _process_chunk_point_to_dist(self, this_metric, key, input_data, truth_data):
                 "single_value does not support parallel processing yet"
             )
             return
-        accumulated_data = this_metric.accumulate(estimate_data, reference_data)
+        accumulated_data = this_metric.accumulate(input_data, truth_data)
         if self.comm:
             self._cached_data[key] = accumulated_data
         else:
@@ -386,7 +387,7 @@ def _process_chunk_point_to_point(self, this_metric, key, input_data, truth_data
                 "single_value does not support parallel processing yet"
             )
             return
-        accumulated_data = this_metric.accumulate(estimate_data, reference_data)
+        accumulated_data = this_metric.accumulate(input_data, truth_data)
         if self.comm:
             self._cached_data[key] = accumulated_data
         else:
diff --git a/tests/evaluation/test_evaluation.py b/tests/evaluation/test_evaluation.py
index 30d22c17..d1845c2a 100644
--- a/tests/evaluation/test_evaluation.py
+++ b/tests/evaluation/test_evaluation.py
@@ -121,7 +121,8 @@ def test_dist_to_point_evaluator():
         _random_state=None,
         metric_config={
             'brier': {'limits':(0,3.1)},
-        }
+        },
+        limits=[0., 3.1],
     )
 
     ensemble = DS.read_file(key='pdfs_data', handle_class=QPHandle, path=pdfs_file)
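
Note: taken together, the evaluator now caches one entry per (metric, point estimate, truth point estimate) combination, and finalize() picks up every cached key whose name starts with the metric name. A short sketch of the resulting key naming, matching the f-string in the diff (the metric and column names are examples only):

    # 'point_stats_ez', 'zmode', 'zmean' and 'redshift' are illustrative.
    metric = "point_stats_ez"
    point_estimates = ["zmode", "zmean"]
    truth_point_estimates = ["redshift"]

    for point_estimate_ in point_estimates:
        for truth_point_estimate_ in truth_point_estimates:
            key_val = f"{metric}_{point_estimate_}_{truth_point_estimate_}"
            print(key_val)
    # point_stats_ez_zmode_redshift
    # point_stats_ez_zmean_redshift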