Adding unit tests for PitMetric wrapper.
drewoldag committed Nov 21, 2023
1 parent f1e9f56 commit 9782c0b
Showing 1 changed file with 39 additions and 18 deletions.
57 changes: 39 additions & 18 deletions tests/qp/test_pit.py
@@ -1,9 +1,13 @@
# pylint: disable=no-member
# pylint: disable=protected-access

import unittest
import numpy as np
import qp
from qp import interp_gen
from qp.ensemble import Ensemble
from qp.metrics.pit import PIT
+from qp.metrics.concrete_metric_classes import PITMetric


# constants for tests
@@ -23,15 +27,16 @@
# BIAS = -0.00001576
# SIGMAD = 0.0046489


class PitTestCase(unittest.TestCase):
""" Test cases for PIT metric. """
"""Test cases for PIT metric."""

    def setUp(self):
        np.random.seed(87)
        self.true_zs = np.random.uniform(high=NMAX, size=NPDF)

        locs = np.expand_dims(self.true_zs + np.random.normal(0.0, 0.01, NPDF), -1)
-        scales = np.ones((NPDF, 1)) * 0.1 + np.random.uniform(size=(NPDF, 1)) * .05
+        scales = np.ones((NPDF, 1)) * 0.1 + np.random.uniform(size=(NPDF, 1)) * 0.05
        self.n_ens = Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales))

        self.grid_ens = self.n_ens.convert_to(interp_gen, xvals=ZGRID)
@@ -49,36 +54,38 @@ def test_pit_metrics(self):

        meta_metrics = pit_obj.calculate_pit_meta_metrics()

-        ad_stat = meta_metrics['ad'].statistic
+        ad_stat = meta_metrics["ad"].statistic
        assert np.isclose(ad_stat, ADVAL_ALL)

-        cut_ad_stat = pit_obj.evaluate_PIT_anderson_ksamp(pit_min=0.6, pit_max=0.9).statistic
+        cut_ad_stat = pit_obj.evaluate_PIT_anderson_ksamp(
+            pit_min=0.6, pit_max=0.9
+        ).statistic
        assert np.isclose(cut_ad_stat, ADVAL_CUT)

-        cvm_stat = meta_metrics['cvm'].statistic
+        cvm_stat = meta_metrics["cvm"].statistic
        assert np.isclose(cvm_stat, CVMVAL)

-        ks_stat = meta_metrics['ks'].statistic
+        ks_stat = meta_metrics["ks"].statistic
        assert np.isclose(ks_stat, KSVAL)

-        assert np.isclose(meta_metrics['outlier_rate'], OUTRATE)
+        assert np.isclose(meta_metrics["outlier_rate"], OUTRATE)

    def test_pit_metric_small_eval_grid(self):
        """Test the PIT metric warning message when the number of pit samples is smaller than the evaluation grid"""
-        with self.assertLogs(level='WARNING') as log:
+        with self.assertLogs(level="WARNING") as log:
            quant_grid = np.linspace(0, 1, 1000)
            _ = PIT(self.grid_ens, self.true_zs, quant_grid)
-            self.assertIn('Number of pit samples is smaller', log.output[0])
+            self.assertIn("Number of pit samples is smaller", log.output[0])

    def test_pit_metric_masking(self):
        """The normal distributions created in this test will produce a quantile
        array in PIT that has multiple values of 1.0. This test will confirm
        that some quants have been removed.
        If no quants had been removed, then the final length would be 101.
        """

-        true_zs = np.random.uniform(low=NMAX-0.1, high=NMAX, size=NPDF)
+        true_zs = np.random.uniform(low=NMAX - 0.1, high=NMAX, size=NPDF)

        locs = np.expand_dims(true_zs + np.random.normal(0.0, 0.01, NPDF), -1)
        scales = np.ones((NPDF, 1)) * 0.001
@@ -93,13 +100,27 @@ def test_pit_metric_masking(self):

    def test_pit_create_quant_mask(self):
        """Basic test where all values should be returned"""
-        input = np.linspace(0.1, 0.9, 10)
-        mask = PIT._create_quant_mask(input)
-        assert np.all(input[mask] == input)
+        input_grid = np.linspace(0.1, 0.9, 10)
+        mask = PIT._create_quant_mask(input_grid)
+        assert np.all(input_grid[mask] == input_grid)

    def test_pit_create_quant_mask_with_exclusions(self):
        """Test with values that should be excluded"""
-        input = np.linspace(-0.1, 1.1, 10)
-        mask = PIT._create_quant_mask(input)
-        assert np.all(input[mask] > 0)
-        assert np.all(input[mask] < 1)
+        input_grid = np.linspace(-0.1, 1.1, 10)
+        mask = PIT._create_quant_mask(input_grid)
+        assert np.all(input_grid[mask] > 0)
+        assert np.all(input_grid[mask] < 1)

+    def test_pit_metric_class_matches_original_pit_class(self):
+        """Compare base test of PIT metric generation with the metric class wrapped
+        version of PIT. We compare the values of the PDFs directly."""
+        quant_grid = np.linspace(0, 1, 101)
+        pit_obj = PIT(self.grid_ens, self.true_zs, quant_grid)
+
+        pit_metric = PITMetric(eval_grid=quant_grid)
+        pit_metric.initialize()
+        class_result = pit_metric.evaluate(self.grid_ens, self.true_zs)
+        pit_metric.finalize()
+
+        eval_grid = np.linspace(0, 3, 100)
+        assert np.all(class_result.pdf(eval_grid) == pit_obj.pit.pdf(eval_grid))
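
For readers skimming the diff, here is a minimal, hypothetical usage sketch of the PITMetric wrapper outside of a unittest class. It is not part of the commit: it only reuses calls that appear in this file (the Ensemble construction from setUp and the initialize/evaluate/finalize sequence from the new test), and the NPDF, NMAX, and ZGRID values are placeholder assumptions, since the module's actual constants are collapsed out of this diff view.

# Hypothetical sketch, not part of the commit: exercises the PITMetric wrapper
# the same way test_pit_metric_class_matches_original_pit_class does.
import numpy as np
import qp
from qp import interp_gen
from qp.ensemble import Ensemble
from qp.metrics.concrete_metric_classes import PITMetric

NPDF = 100                          # assumed value; the real constant is defined in the test module
NMAX = 3.0                          # assumed value
ZGRID = np.linspace(0, NMAX, 301)   # assumed evaluation grid

np.random.seed(87)
true_zs = np.random.uniform(high=NMAX, size=NPDF)
locs = np.expand_dims(true_zs + np.random.normal(0.0, 0.01, NPDF), -1)
scales = np.ones((NPDF, 1)) * 0.1 + np.random.uniform(size=(NPDF, 1)) * 0.05
grid_ens = Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)).convert_to(
    interp_gen, xvals=ZGRID
)

# Wrapper lifecycle exercised by the new test: configure, initialize, evaluate, finalize.
pit_metric = PITMetric(eval_grid=np.linspace(0, 1, 101))
pit_metric.initialize()
result = pit_metric.evaluate(grid_ens, true_zs)
pit_metric.finalize()

# The result is distribution-like; the new test compares its pdf against PIT's directly.
print(result.pdf(np.linspace(0, 3, 100)))

The expectation encoded in the new test is simply that this wrapped result reproduces the distribution held by PIT(...).pit, value for value.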
