Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added metric factory #209

Merged
merged 6 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/qp/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,17 @@
from .array_metrics import *
from .metrics import *
from .goodness_of_fit import *
from .base_metric_classes import *

# Private helpers re-exported so the unit tests can exercise them directly.

from .metrics import _calculate_grid_parameters, _check_ensemble_is_not_nested, _check_ensembles_are_same_size
from .metrics import _check_ensembles_contain_correct_number_of_distributions

from .factory import MetricFactory


# Module-level conveniences: expose the factory's classmethods as plain
# functions, e.g. `qp.metrics.create_metric("kld")`.
create_metric = MetricFactory.create_metric
print_metrics = MetricFactory.print_metrics
list_metrics = MetricFactory.list_metrics
update_metrics = MetricFactory.update_metrics
6 changes: 3 additions & 3 deletions src/qp/metrics/base_metric_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def uses_point_for_reference(self) -> bool:
]


class MetricOuputType(enum.Enum):
class MetricOutputType(enum.Enum):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤦 <- I'm glad there's a facepalm emoji in slack.

"""Defines the various output types that metric classes can return."""

unknown = -1
Expand All @@ -78,7 +78,7 @@ class BaseMetric(ABC):
MetricInputType.unknown
) # The type of input data expected for this metric
metric_output_type = (
MetricOuputType.unknown
MetricOutputType.unknown
) # The form of the output data from this metric

def __init__(self, limits: tuple = (0.0, 3.0), dx: float = 0.01) -> None:
Expand Down Expand Up @@ -112,7 +112,7 @@ class SingleEnsembleMetric(BaseMetric):
"""A base class for metrics that accept only a single ensemble as input."""

metric_input_type = MetricInputType.single_ensemble
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def evaluate(self, estimate):
raise NotImplementedError()
Expand Down
22 changes: 11 additions & 11 deletions src/qp/metrics/concrete_metric_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from qp.ensemble import Ensemble
from qp.metrics.base_metric_classes import (
MetricOuputType,
MetricOutputType,
DistToDistMetric,
DistToPointMetric,
SingleEnsembleMetric,
Expand Down Expand Up @@ -44,7 +44,7 @@ class KLDMetric(DistToDistMetric):
"""Class wrapper around the KLD metric"""

metric_name = "kld"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(self, limits: tuple = (0.0, 3.0), dx: float = 0.01, **kwargs) -> None:
super().__init__(limits, dx)
Expand All @@ -57,7 +57,7 @@ class RMSEMetric(DistToDistMetric):
"""Class wrapper around the Root Mean Square Error metric"""

metric_name = "rmse"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(self, limits: tuple = (0.0, 3.0), dx: float = 0.01, **kwargs) -> None:
super().__init__(limits, dx)
Expand All @@ -70,7 +70,7 @@ class RBPEMetric(SingleEnsembleMetric):
"""Class wrapper around the Risk Based Point Estimate metric."""

metric_name = "rbpe"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(self, limits: tuple = (np.inf, np.inf), **kwargs) -> None:
super().__init__(limits)
Expand All @@ -86,7 +86,7 @@ class BrierMetric(DistToPointMetric):
"""

metric_name = "brier"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(self, limits: tuple = (0.0, 3.0), dx: float = 0.01, **kwargs) -> None:
super().__init__(limits, dx)
Expand All @@ -99,7 +99,7 @@ class OutlierMetric(SingleEnsembleMetric):
"""Class wrapper around the outlier calculation metric."""

metric_name = "outlier"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(self, cdf_limits: tuple = (0.0001, 0.9999), **kwargs) -> None:
super().__init__()
Expand All @@ -115,7 +115,7 @@ class ADMetric(DistToDistMetric):
"""Class wrapper for Anderson Darling metric."""

metric_name = "ad"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(
self, num_samples: int = 100, _random_state: float = None, **kwargs
Expand Down Expand Up @@ -146,7 +146,7 @@ class CvMMetric(DistToDistMetric):
"""Class wrapper for Cramer von Mises metric."""

metric_name = "cvm"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(
self, num_samples: int = 100, _random_state: float = None, **kwargs
Expand Down Expand Up @@ -177,7 +177,7 @@ class KSMetric(DistToDistMetric):
"""Class wrapper for Kolmogorov Smirnov metric."""

metric_name = "ks"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(
self, num_samples: int = 100, _random_state: float = None, **kwargs
Expand Down Expand Up @@ -209,7 +209,7 @@ class PITMetric(DistToPointMetric):
"""Class wrapper for the PIT Metric class."""

metric_name = "pit"
metric_output_type = MetricOuputType.single_distribution
metric_output_type = MetricOutputType.single_distribution
default_eval_grid = np.linspace(0, 1, 100)

def __init__(self, eval_grid: list = default_eval_grid, **kwargs) -> None:
Expand All @@ -225,7 +225,7 @@ class CDELossMetric(DistToPointMetric):
"""Conditional density loss"""

metric_name = "cdeloss"
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution
default_eval_grid = np.linspace(0, 2.5, 301)

def __init__(self, eval_grid: list = default_eval_grid, **kwargs) -> None:
Expand Down
50 changes: 50 additions & 0 deletions src/qp/metrics/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from .base_metric_classes import BaseMetric


def get_all_subclasses(cls):
    """Recursively collect every direct and indirect subclass of *cls*.

    The result is in depth-first pre-order: each direct subclass is
    followed immediately by its own descendants.
    """
    # Explicit stack instead of recursion; reversing at each push keeps
    # the traversal in the same pre-order the recursive form produces.
    stack = list(reversed(cls.__subclasses__()))
    collected = []
    while stack:
        current = stack.pop()
        collected.append(current)
        stack.extend(reversed(current.__subclasses__()))
    return collected


class MetricFactory:
    """Registry/factory for metric evaluator classes.

    Discovers every subclass of ``BaseMetric`` that declares a
    ``metric_name`` and lets callers instantiate one by that name.
    """

    # Maps metric name -> metric class; populated lazily on first use.
    metric_dict = {}

    @classmethod
    def update_metrics(cls):
        """ Rebuild the registry from the current set of BaseMetric subclasses """
        registry = {}
        for subcls in get_all_subclasses(BaseMetric):
            # Skip abstract intermediates that never set a metric_name.
            if subcls.metric_name:
                registry[subcls.metric_name.replace("Metric", "")] = subcls
        cls.metric_dict = registry

    @classmethod
    def print_metrics(cls, force_update=False):
        """ Print every known metric with its class and expected input type """
        if force_update or not cls.metric_dict:
            cls.update_metrics()
        print('List of available metrics')
        for name, metric_cls in cls.metric_dict.items():
            print(name, metric_cls, metric_cls.metric_input_type.name)

    @classmethod
    def list_metrics(cls, force_update=False):
        """ Return the names of all the known metrics """
        if force_update or not cls.metric_dict:
            cls.update_metrics()
        return list(cls.metric_dict.keys())

    @classmethod
    def create_metric(cls, name, force_update=False, **kwargs):
        """ Instantiate the metric registered under *name*, forwarding **kwargs """
        if force_update or not cls.metric_dict:
            cls.update_metrics()
        try:
            metric_class = cls.metric_dict[name]
        except KeyError as err:
            raise KeyError(f"{name} is not in the set of known metrics {str(cls.list_metrics())}") from err
        return metric_class(**kwargs)
12 changes: 6 additions & 6 deletions src/qp/metrics/point_estimate_metric_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from qp.metrics.base_metric_classes import (
MetricOuputType,
MetricOutputType,
PointToPointMetric,
)

Expand All @@ -12,7 +12,7 @@ class PointStatsEz(PointToPointMetric):
metric_name = "point_stats_ez"

#! This doesn't seem quiet correct, perhaps we need a `single_value_per_input_element` ???
metric_output_type = MetricOuputType.one_value_per_distribution
metric_output_type = MetricOutputType.one_value_per_distribution

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -41,7 +41,7 @@ class PointSigmaIQR(PointToPointMetric):
"""Calculate sigmaIQR"""

metric_name = "point_stats_iqr"
metric_output_type = MetricOuputType.single_value
metric_output_type = MetricOutputType.single_value

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -75,7 +75,7 @@ class PointBias(PointToPointMetric):
"""

metric_name = "point_bias"
metric_output_type = MetricOuputType.single_value
metric_output_type = MetricOutputType.single_value

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -106,7 +106,7 @@ class PointOutlierRate(PointToPointMetric):
"""

metric_name = "point_outlier_rate"
metric_output_type = MetricOuputType.single_value
metric_output_type = MetricOutputType.single_value

def __init__(self) -> None:
super().__init__()
Expand Down Expand Up @@ -144,7 +144,7 @@ class PointSigmaMAD(PointToPointMetric):
"""

metric_name = "point_stats_sigma_mad"
metric_output_type = MetricOuputType.single_value
metric_output_type = MetricOutputType.single_value

def __init__(self) -> None:
super().__init__()
Expand Down
51 changes: 51 additions & 0 deletions tests/qp/test_metric_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# pylint: disable=no-member
# pylint: disable=protected-access

"""
Unit tests for PDF class
"""
import unittest
import numpy as np
import qp
import qp.metrics


class MetricFactoryTestCase(unittest.TestCase):
    """Tests for the MetricFactory registry exposed via qp.metrics."""

    def setUp(self):
        """
        Refresh the metric registry before each test so discovery state
        does not leak between tests.
        """
        qp.metrics.update_metrics()

    def test_print_metrics(self):
        """Test printing the metrics, with and without a forced registry update."""
        qp.metrics.print_metrics()
        qp.metrics.print_metrics(force_update=True)


    def test_list_metrics(self):
        """Test listing the metric names, with and without a forced registry update."""
        the_list = qp.metrics.list_metrics()
        assert the_list
        the_list = qp.metrics.list_metrics(force_update=True)
        assert the_list

    def test_create_metrics(self):
        """Test creating every known metric by name."""
        all_metric_names = qp.metrics.list_metrics(force_update=True)
        for metric_name in all_metric_names:
            a_metric = qp.metrics.create_metric(metric_name)
            # Each created instance must report the name it was requested under.
            assert a_metric.metric_name == metric_name
        a_metric = qp.metrics.create_metric("outlier", force_update=True)
        assert a_metric.metric_name == 'outlier'

    def test_bad_metric_name(self):
        """ Catch the KeyError raised for an unknown metric name """
        with self.assertRaises(KeyError):
            qp.metrics.create_metric("Bad Metric")


# Allow running this test module directly: `python test_metric_factory.py`.
if __name__ == "__main__":
    unittest.main()
Loading