diff --git a/src/qp/factory.py b/src/qp/factory.py index a8ef2c0..9f4b581 100644 --- a/src/qp/factory.py +++ b/src/qp/factory.py @@ -54,7 +54,7 @@ def _build_data_dict(md_table, data_table): data_dict[col] = col_data return data_dict - def _make_scipy_wrapped_class(self, class_name, scipy_class): + def _make_scipy_wrapped_class(self, class_name, scipy_class, ctor_param): """Build a qp class from a scipy class""" # pylint: disable=protected-access override_dict = dict( @@ -62,6 +62,7 @@ def _make_scipy_wrapped_class(self, class_name, scipy_class): version=0, freeze=Pdf_gen_wrap._my_freeze, _other_init=scipy_class.__init__, + _ctor_param=ctor_param, ) the_class = type(class_name, (Pdf_gen_wrap, scipy_class), override_dict) self.add_class(the_class) @@ -72,7 +73,7 @@ def _load_scipy_classes(self): for name in names: attr = getattr(sps, name) if isinstance(attr, sps.rv_continuous): - self._make_scipy_wrapped_class(name, type(attr)) + self._make_scipy_wrapped_class(name, type(attr), attr._updated_ctor_param()) def add_class(self, the_class): """Add a class to the factory diff --git a/src/qp/mixmod_pdf.py b/src/qp/mixmod_pdf.py index e9cd405..649ca66 100644 --- a/src/qp/mixmod_pdf.py +++ b/src/qp/mixmod_pdf.py @@ -17,16 +17,18 @@ class mixmod_gen(Pdf_rows_gen): Notes ----- - This implements a PDF using a Gaussian Mixture model + This is a base class for implementing PDFs using a mixture model. + Classes implementing mixture models with specific basis functions + need to define the pdf and cdf. The relevant data members are: - means: (npdf, ncomp) means of the Gaussians - stds: (npdf, ncomp) standard deviations of the Gaussians - weights: (npdf, ncomp) weights for the Gaussians + means: (npdf, ncomp) means of the basis functions + stds: (npdf, ncomp) standard deviations of the basis functions + weights: (npdf, ncomp) weights for the basis functions The pdf() and cdf() are exact, and are computed as a weighted sum of - the pdf() and cdf() of the component Gaussians. + the pdf() and cdf() of the component basis functions. The ppf() is computed by computing the cdf() values on a fixed grid and interpolating the inverse function. @@ -34,38 +36,42 @@ class mixmod_gen(Pdf_rows_gen): # pylint: disable=protected-access - name = "mixmod" - version = 0 - _support_mask = rv_continuous._support_mask - def __init__(self, means, stds, weights, *args, **kwargs): + name = 'mixmod' + version = 0 + + def __init__(self, gen_func, weights, data, ancil=None, *args, **kwargs): """ Create a new distribution using the given histogram Parameters ---------- means : array_like - The means of the Gaussians + The means of the basis functions stds: array_like - The standard deviations of the Gaussians + The standard deviations of the basis functions weights : array_like The weights to attach to the Gaussians. Weights should sum up to one. If not, the weights are interpreted as relative weights. """ + self._gen_func = gen_func + self._frozen = self._gen_func(**data) + self._gen_obj = self._frozen.dist + self._gen_class = type(self._gen_obj) + self._data = data + self._scipy_version_warning() - self._means = reshape_to_pdf_size(means, -1) - self._stds = reshape_to_pdf_size(stds, -1) self._weights = reshape_to_pdf_size(weights, -1) - kwargs["shape"] = means.shape[:-1] - self._ncomps = means.shape[-1] + for key in self._data.keys(): + self._data[key] = reshape_to_pdf_size(self._data[key],-1) + kwargs['shape'] = weights.shape[:-1] + self._ncomps = weights.shape[-1] super().__init__(*args, **kwargs) - if np.any(self._weights < 0): - raise ValueError("All weights need to be larger than zero") - self._weights = self._weights / self._weights.sum(axis=1)[:, None] - self._addobjdata("weights", self._weights) - self._addobjdata("stds", self._stds) - self._addobjdata("means", self._means) + if np.any(self._weights<0): + raise ValueError('All weights need to be larger than zero') + self._weights = self._weights/self._weights.sum(axis=1)[:,None] + self._addobjdata('weights', self._weights) def _scipy_version_warning(self): import scipy # pylint: disable=import-outside-toplevel @@ -80,42 +86,28 @@ def _scipy_version_warning(self): @property def weights(self): - """Return weights to attach to the Gaussians""" + """Return weights to attach to the basis functions""" return self._weights - @property - def means(self): - """Return means of the Gaussians""" - return self._means - - @property - def stds(self): - """Return standard deviations of the Gaussians""" - return self._stds - def _pdf(self, x, row): # pylint: disable=arguments-differ if np.ndim(x) > 1: # pragma: no cover x = np.expand_dims(x, -2) - return ( - self.weights[row].swapaxes(-2, -1) - * sps.norm( - loc=self._means[row].swapaxes(-2, -1), - scale=self._stds[row].swapaxes(-2, -1), - ).pdf(x) - ).sum(axis=0) + data_swap=dict() + for key in self._data.keys(): + data_swap[key] = self._data[key][row].swapaxes(-2,-1) + return (self.weights[row].swapaxes(-2,-1) * + self._gen_func(**data_swap).pdf(x)).sum(axis=0) def _cdf(self, x, row): # pylint: disable=arguments-differ if np.ndim(x) > 1: # pragma: no cover x = np.expand_dims(x, -2) - return ( - self.weights[row].swapaxes(-2, -1) - * sps.norm( - loc=self._means[row].swapaxes(-2, -1), - scale=self._stds[row].swapaxes(-2, -1), - ).cdf(x) - ).sum(axis=0) + data_swap=dict() + for key in self._data.keys(): + data_swap[key] = self._data[key][row].swapaxes(-2,-1) + return (self.weights[row].swapaxes(-2,-1) * + self._gen_func(**data_swap).cdf(x)).sum(axis=0) def _ppf(self, x, row): # pylint: disable=arguments-differ @@ -140,9 +132,11 @@ def _updated_ctor_param(self): Set the bins as additional constructor argument """ dct = super()._updated_ctor_param() - dct["means"] = self._means - dct["stds"] = self._stds - dct["weights"] = self._weights + # for key in self._data.keys(): + # dct[key] = self._data[key] + dct['weights'] = self._weights + dct['data'] = self._data + dct['gen_func'] = self._gen_func return dct @classmethod @@ -191,7 +185,7 @@ def make_test_data(cls): ) ) - mixmod = mixmod_gen.create add_class(mixmod_gen) + diff --git a/src/qp/pdf_gen.py b/src/qp/pdf_gen.py index 36fefac..f2052eb 100644 --- a/src/qp/pdf_gen.py +++ b/src/qp/pdf_gen.py @@ -393,7 +393,10 @@ def __init__(self, *args, **kwargs): """C'tor""" # pylint: disable=no-member,protected-access super().__init__(*args, **kwargs) - self._other_init(*args, **kwargs) + if kwargs==self._ctor_param: + kwargs=dict() + kwargs.pop('name', None) + self._other_init(*args, **kwargs, **self._ctor_param) def _my_freeze(self, *args, **kwds): """Freeze the distribution for the given arguments. diff --git a/tests/qp/test_auto.py b/tests/qp/test_auto.py index c4aedc0..ce41dff 100644 --- a/tests/qp/test_auto.py +++ b/tests/qp/test_auto.py @@ -80,7 +80,8 @@ def auto_add(cls, class_list, ens_orig): ENS_MULTI = test_funcs.build_ensemble( qp.stats.norm_gen.test_data["norm"] # pylint: disable=no-member ) -TEST_CLASSES = qp.instance().values() +TEST_CLASSES = list(qp.instance().values()) +TEST_CLASSES.remove(qp.mixmod_pdf.mixmod_gen) PDFTestCase.auto_add(TEST_CLASSES, [ENS_ORIG, ENS_MULTI]) diff --git a/tests/qp/test_ensemble.py b/tests/qp/test_ensemble.py index 46a0cbe..6fc5951 100644 --- a/tests/qp/test_ensemble.py +++ b/tests/qp/test_ensemble.py @@ -235,6 +235,13 @@ def test_mixmod_with_negative_weights(self): with self.assertRaises(ValueError): _ = qp.mixmod(weights=weights, means=means, stds=sigmas) + def test_mixmod_with_negative_weights(self): + """Verify that an exception is raised when setting up a mixture model with negative weights""" + means = np.array([0.5,1.1, 2.9]) + sigmas = np.array([0.15,0.13,0.14]) + weights = np.array([1,0.5,-0.25]) + with self.assertRaises(ValueError): + _ = qp.mixmod(gen_func=qp.stats.norm, weights=weights, data = dict(loc=means, scale=sigmas)) -if __name__ == "__main__": +if __name__ == '__main__': unittest.main()