diff --git a/src/qp/metrics/base_metric_classes.py b/src/qp/metrics/base_metric_classes.py index 3077db7..a49f5e2 100644 --- a/src/qp/metrics/base_metric_classes.py +++ b/src/qp/metrics/base_metric_classes.py @@ -28,26 +28,26 @@ class MetricInputType(enum.Enum): point_to_dist = 4 - def uses_distribution_for_estimate(self): + def uses_distribution_for_estimate(self) -> bool: return self in [ MetricInputType.single_ensemble, MetricInputType.dist_to_point, MetricInputType.dist_to_dist, ] - def uses_distribution_for_reference(self): + def uses_distribution_for_reference(self) -> bool: return self in [ MetricInputType.dist_to_dist, MetricInputType.point_to_dist, ] - def uses_point_for_estimate(self): + def uses_point_for_estimate(self) -> bool: return self in [ MetricInputType.point_to_dist, MetricInputType.point_to_dist, ] - def uses_point_for_reference(self): + def uses_point_for_reference(self) -> bool: return self in [ MetricInputType.dist_to_point, MetricInputType.point_to_point, @@ -80,29 +80,20 @@ def __init__(self, limit:tuple=(0.0, 3.0), dx:float=0.01) -> None: self._limit = limit self._dx = dx - def initialize(self) -> None: - pass - - def evaluate(self) -> None: - pass - - def finalize(self) -> None: - pass - @classmethod - def uses_distribution_for_estimate(cls): + def uses_distribution_for_estimate(cls) -> bool: return cls.metric_input_type.uses_distribution_for_estimate() @classmethod - def uses_distribution_for_reference(cls): + def uses_distribution_for_reference(cls) -> bool: return cls.metric_input_type.uses_distribution_for_reference() @classmethod - def uses_point_for_estimate(cls): + def uses_point_for_estimate(cls) -> bool: return cls.metric_input_type.uses_point_for_estimate() @classmethod - def uses_point_for_reference(cls): + def uses_point_for_reference(cls) -> bool: return cls.metric_input_type.uses_point_for_reference() @@ -121,6 +112,15 @@ class DistToPointMetric(BaseMetric): metric_input_type = MetricInputType.dist_to_point + def initialize(self, **kwargs): + raise NotImplementedError() + + def evaluate(self, estimate, reference, **kwargs): + raise NotImplementedError() + + def finalize(self, **kwargs): + raise NotImplementedError() + class PointToPointMetric(BaseMetric): diff --git a/src/qp/metrics/concrete_metric_classes.py b/src/qp/metrics/concrete_metric_classes.py new file mode 100644 index 0000000..8db3fc4 --- /dev/null +++ b/src/qp/metrics/concrete_metric_classes.py @@ -0,0 +1,23 @@ +from qp.metrics.base_metric_classes import MetricOuputType, DistToDistMetric +from qp.metrics.metrics import calculate_kld + +class KLDMetric(DistToDistMetric): + """Class wrapper around the KLD metric + """ + + metric_name = "kld" + metric_output_type = MetricOuputType.one_value_per_distribution + + def __init__(self, limits:tuple=(0.0, 3.0), dx:float=0.01) -> None: + super().__init__() + self._limits = limits + self._dx = dx + + def initialize(self) -> None: + pass + + def evaluate(self, estimate, reference) -> None: + return calculate_kld(estimate, reference, self._limits, self._dx) + + def finalize(self) -> None: + pass