diff --git a/src/qp/metrics/pit.py b/src/qp/metrics/pit.py index 9f6b1f8..e6c3bd9 100644 --- a/src/qp/metrics/pit.py +++ b/src/qp/metrics/pit.py @@ -46,6 +46,14 @@ def __init__(self, qp_ens, true_vals, eval_grid=DEFAULT_QUANTS): # For each distribution in the Ensemble, calculate the CDF where x = known_true_value self._pit_samps = np.array([qp_ens[i].cdf(self._true_vals[i])[0][0] for i in range(len(self._true_vals))]) + # These two lines set all `NaN` values to 0. This may or may not make sense + # Alternatively if it's better to simply remove the `NaN`, this can be done + # efficiently on line 61 with `data_quants = np.nanquantile(...)`.` + samp_mask = np.isfinite(self._pit_samps) + self._pit_samps[~samp_mask] = 0 + if not np.all(samp_mask): + logging.warning('Some PIT samples were `NaN`. They have been replacd with 0.') + n_pit = np.min([len(self._pit_samps), len(eval_grid)]) if n_pit < len(eval_grid): logging.warning('Number of pit samples is smaller than the evaluation grid size. '