Skip to content

Commit

Permalink
Adding KLD wrapper.
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag committed Nov 8, 2023
1 parent 441df77 commit 49219e3
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/qp/metrics/base_metric_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,26 @@ class MetricInputType(enum.Enum):
point_to_dist = 4


def uses_distribution_for_estimate(self):
def uses_distribution_for_estimate(self) -> bool:
return self in [
MetricInputType.single_ensemble,
MetricInputType.dist_to_point,
MetricInputType.dist_to_dist,
]

def uses_distribution_for_reference(self):
def uses_distribution_for_reference(self) -> bool:
return self in [
MetricInputType.dist_to_dist,
MetricInputType.point_to_dist,
]

def uses_point_for_estimate(self):
def uses_point_for_estimate(self) -> bool:
return self in [
MetricInputType.point_to_dist,
MetricInputType.point_to_dist,
]

def uses_point_for_reference(self):
def uses_point_for_reference(self) -> bool:
return self in [
MetricInputType.dist_to_point,
MetricInputType.point_to_point,
Expand Down Expand Up @@ -80,29 +80,20 @@ def __init__(self, limit:tuple=(0.0, 3.0), dx:float=0.01) -> None:
self._limit = limit
self._dx = dx

def initialize(self) -> None:
pass

def evaluate(self) -> None:
pass

def finalize(self) -> None:
pass

@classmethod
def uses_distribution_for_estimate(cls):
def uses_distribution_for_estimate(cls) -> bool:
return cls.metric_input_type.uses_distribution_for_estimate()

@classmethod
def uses_distribution_for_reference(cls):
def uses_distribution_for_reference(cls) -> bool:
return cls.metric_input_type.uses_distribution_for_reference()

@classmethod
def uses_point_for_estimate(cls):
def uses_point_for_estimate(cls) -> bool:
return cls.metric_input_type.uses_point_for_estimate()

@classmethod
def uses_point_for_reference(cls):
def uses_point_for_reference(cls) -> bool:
return cls.metric_input_type.uses_point_for_reference()


Expand All @@ -121,6 +112,15 @@ class DistToPointMetric(BaseMetric):

metric_input_type = MetricInputType.dist_to_point

def initialize(self, **kwargs):
raise NotImplementedError()

def evaluate(self, estimate, reference, **kwargs):
raise NotImplementedError()

def finalize(self, **kwargs):
raise NotImplementedError()


class PointToPointMetric(BaseMetric):

Expand Down
23 changes: 23 additions & 0 deletions src/qp/metrics/concrete_metric_classes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from qp.metrics.base_metric_classes import MetricOuputType, DistToDistMetric
from qp.metrics.metrics import calculate_kld

class KLDMetric(DistToDistMetric):
"""Class wrapper around the KLD metric
"""

metric_name = "kld"
metric_output_type = MetricOuputType.one_value_per_distribution

def __init__(self, limits:tuple=(0.0, 3.0), dx:float=0.01) -> None:
super().__init__()
self._limits = limits
self._dx = dx

def initialize(self) -> None:
pass

def evaluate(self, estimate, reference) -> None:
return calculate_kld(estimate, reference, self._limits, self._dx)

def finalize(self) -> None:
pass

0 comments on commit 49219e3

Please sign in to comment.