Skip to content

Commit

Permalink
Addressing PR comments. Consolidated duplicated method definitions. U…
Browse files Browse the repository at this point in the history
…pdated tests.
  • Loading branch information
drewoldag committed Nov 21, 2023
1 parent 9782c0b commit 838ba13
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 151 deletions.
71 changes: 33 additions & 38 deletions src/qp/metrics/base_metric_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

from abc import ABC


class MetricInputType(enum.Enum):
"""Defines the various combinations of input types that metric classes accept.
"""
"""Defines the various combinations of input types that metric classes accept."""

unknown = -1

Expand All @@ -27,7 +27,6 @@ class MetricInputType(enum.Enum):
# distribution, or collection of distributions for reference(s).
point_to_dist = 4


def uses_distribution_for_estimate(self) -> bool:
return self in [
MetricInputType.single_ensemble,
Expand Down Expand Up @@ -55,8 +54,7 @@ def uses_point_for_reference(self) -> bool:


class MetricOuputType(enum.Enum):
"""Defines the various output types that metric classes can return.
"""
"""Defines the various output types that metric classes can return."""

unknown = -1

Expand All @@ -71,15 +69,28 @@ class MetricOuputType(enum.Enum):


class BaseMetric(ABC):
metric_name = None # The name for this metric, overwritten in subclasses
metric_input_type = MetricInputType.unknown # The type of input data expected for this metric
metric_output_type = MetricOuputType.unknown # The form of the output data from this metric
"""This is the base class for all of the qp metrics. It establishes the most
of the basic API for a consistent interaction with the metrics qp provides.
"""

def __init__(self, limits:tuple=(0.0, 3.0), dx:float=0.01) -> None:
metric_name = None # The name for this metric, overwritten in subclasses
metric_input_type = (
MetricInputType.unknown
) # The type of input data expected for this metric
metric_output_type = (
MetricOuputType.unknown
) # The form of the output data from this metric

def __init__(self, limits: tuple = (0.0, 3.0), dx: float = 0.01) -> None:
self._limits = limits
self._dx = dx

def initialize(self):
pass

def finalize(self):
pass

@classmethod
def uses_distribution_for_estimate(cls) -> bool:
return cls.metric_input_type.uses_distribution_for_estimate()
Expand All @@ -98,70 +109,54 @@ def uses_point_for_reference(cls) -> bool:


class SingleEnsembleMetric(BaseMetric):
"""A base class for metrics that accept only a single ensemble as input."""

metric_input_type = MetricInputType.single_ensemble
metric_output_type = MetricOuputType.one_value_per_distribution

def initialize(self):
raise NotImplementedError()

def evaluate(self, estimate):
raise NotImplementedError()

def finalize(self):
raise NotImplementedError()


class DistToDistMetric(BaseMetric):
"""A base class for metrics that requires distributions as input for both the
estimated and reference values.
"""

metric_input_type = MetricInputType.dist_to_dist

def initialize(self):
raise NotImplementedError()

def evaluate(self, estimate, reference):
raise NotImplementedError()

def finalize(self):
raise NotImplementedError()


class DistToPointMetric(BaseMetric):
"""A base class for metrics that require a distribution as the estimated
value and a point estimate as the reference value.
"""

metric_input_type = MetricInputType.dist_to_point

def initialize(self):
raise NotImplementedError()

def evaluate(self, estimate, reference):
raise NotImplementedError()

def finalize(self):
raise NotImplementedError()


class PointToPointMetric(BaseMetric):
"""A base class for metrics that require a point estimate as input for both
the estimated and reference values.
"""

metric_input_type = MetricInputType.point_to_point

def initialize(self):
raise NotImplementedError()

def evaluate(self, estimate, reference):
raise NotImplementedError()

def finalize(self):
raise NotImplementedError()

class PointToDistMetric(BaseMetric):
"""A base class for metrics that require a point estimate as the estimated
value and a distribution as the reference value.
"""

metric_input_type = MetricInputType.point_to_dist

def initialize(self):
raise NotImplementedError()

def evaluate(self, estimate, reference):
raise NotImplementedError()

def finalize(self):
raise NotImplementedError()
Loading

0 comments on commit 838ba13

Please sign in to comment.