From 9782c0b6ddd2131507602c72af2b61c72928052e Mon Sep 17 00:00:00 2001 From: Drew Oldag Date: Mon, 20 Nov 2023 16:39:59 -0800 Subject: [PATCH] Adding unit tests for `PitMetric` wrapper. --- tests/qp/test_pit.py | 57 ++++++++++++++++++++++++++++++-------------- 1 file changed, 39 insertions(+), 18 deletions(-) diff --git a/tests/qp/test_pit.py b/tests/qp/test_pit.py index 5cd2ba9..4ce4e1f 100644 --- a/tests/qp/test_pit.py +++ b/tests/qp/test_pit.py @@ -1,9 +1,13 @@ +# pylint: disable=no-member +# pylint: disable=protected-access + import unittest import numpy as np import qp from qp import interp_gen from qp.ensemble import Ensemble from qp.metrics.pit import PIT +from qp.metrics.concrete_metric_classes import PITMetric # constants for tests @@ -23,15 +27,16 @@ # BIAS = -0.00001576 # SIGMAD = 0.0046489 + class PitTestCase(unittest.TestCase): - """ Test cases for PIT metric. """ + """Test cases for PIT metric.""" def setUp(self): np.random.seed(87) self.true_zs = np.random.uniform(high=NMAX, size=NPDF) locs = np.expand_dims(self.true_zs + np.random.normal(0.0, 0.01, NPDF), -1) - scales = np.ones((NPDF, 1)) * 0.1 + np.random.uniform(size=(NPDF, 1)) * .05 + scales = np.ones((NPDF, 1)) * 0.1 + np.random.uniform(size=(NPDF, 1)) * 0.05 self.n_ens = Ensemble(qp.stats.norm, data=dict(loc=locs, scale=scales)) self.grid_ens = self.n_ens.convert_to(interp_gen, xvals=ZGRID) @@ -49,36 +54,38 @@ def test_pit_metrics(self): meta_metrics = pit_obj.calculate_pit_meta_metrics() - ad_stat = meta_metrics['ad'].statistic + ad_stat = meta_metrics["ad"].statistic assert np.isclose(ad_stat, ADVAL_ALL) - cut_ad_stat = pit_obj.evaluate_PIT_anderson_ksamp(pit_min=0.6, pit_max=0.9).statistic + cut_ad_stat = pit_obj.evaluate_PIT_anderson_ksamp( + pit_min=0.6, pit_max=0.9 + ).statistic assert np.isclose(cut_ad_stat, ADVAL_CUT) - cvm_stat = meta_metrics['cvm'].statistic + cvm_stat = meta_metrics["cvm"].statistic assert np.isclose(cvm_stat, CVMVAL) - ks_stat = meta_metrics['ks'].statistic + ks_stat = meta_metrics["ks"].statistic assert np.isclose(ks_stat, KSVAL) - assert np.isclose(meta_metrics['outlier_rate'], OUTRATE) + assert np.isclose(meta_metrics["outlier_rate"], OUTRATE) def test_pit_metric_small_eval_grid(self): """Test PIT metric warning message when number of pit samples is smaller than the evaluation grid""" - with self.assertLogs(level='WARNING') as log: + with self.assertLogs(level="WARNING") as log: quant_grid = np.linspace(0, 1, 1000) _ = PIT(self.grid_ens, self.true_zs, quant_grid) - self.assertIn('Number of pit samples is smaller', log.output[0]) + self.assertIn("Number of pit samples is smaller", log.output[0]) def test_pit_metric_masking(self): """The normal distributions created in this test will produce a quantile array in PIT that have multiple values of 1.0. This test will confirm - that the some quants have been removed. + that the some quants have been removed. If no quants had been removed, then the final length would be 101. """ - true_zs = np.random.uniform(low=NMAX-0.1, high=NMAX, size=NPDF) + true_zs = np.random.uniform(low=NMAX - 0.1, high=NMAX, size=NPDF) locs = np.expand_dims(true_zs + np.random.normal(0.0, 0.01, NPDF), -1) scales = np.ones((NPDF, 1)) * 0.001 @@ -93,13 +100,27 @@ def test_pit_metric_masking(self): def test_pit_create_quant_mask(self): """Basic test where all values should be returned""" - input = np.linspace(0.1, 0.9, 10) - mask = PIT._create_quant_mask(input) - assert np.all(input[mask] == input) + input_grid = np.linspace(0.1, 0.9, 10) + mask = PIT._create_quant_mask(input_grid) + assert np.all(input_grid[mask] == input_grid) def test_pit_create_quant_mask_with_exclusions(self): """Test with values that should be excluded""" - input = np.linspace(-0.1, 1.1, 10) - mask = PIT._create_quant_mask(input) - assert np.all(input[mask] > 0) - assert np.all(input[mask] < 1) \ No newline at end of file + input_grid = np.linspace(-0.1, 1.1, 10) + mask = PIT._create_quant_mask(input_grid) + assert np.all(input_grid[mask] > 0) + assert np.all(input_grid[mask] < 1) + + def test_pit_metric_class_matches_original_pit_class(self): + """Compare base test of PIT metric generation with the metric class wrapped + version of PIT. We compare the values of the PDFs directly.""" + quant_grid = np.linspace(0, 1, 101) + pit_obj = PIT(self.grid_ens, self.true_zs, quant_grid) + + pit_metric = PITMetric(eval_grid=quant_grid) + pit_metric.initialize() + class_result = pit_metric.evaluate(self.grid_ens, self.true_zs) + pit_metric.finalize() + + eval_grid = np.linspace(0, 3, 100) + assert np.all(class_result.pdf(eval_grid) == pit_obj.pit.pdf(eval_grid))