Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue/166/generalize mixmod pdf #199

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/qp/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ def _build_data_dict(md_table, data_table):
data_dict[col] = col_data
return data_dict

def _make_scipy_wrapped_class(self, class_name, scipy_class):
def _make_scipy_wrapped_class(self, class_name, scipy_class, ctor_param):
"""Build a qp class from a scipy class"""
# pylint: disable=protected-access
override_dict = dict(
name=class_name,
version=0,
freeze=Pdf_gen_wrap._my_freeze,
_other_init=scipy_class.__init__,
_ctor_param=ctor_param,
)
the_class = type(class_name, (Pdf_gen_wrap, scipy_class), override_dict)
self.add_class(the_class)
Expand All @@ -72,7 +73,7 @@ def _load_scipy_classes(self):
for name in names:
attr = getattr(sps, name)
if isinstance(attr, sps.rv_continuous):
self._make_scipy_wrapped_class(name, type(attr))
self._make_scipy_wrapped_class(name, type(attr), attr._updated_ctor_param())

def add_class(self, the_class):
"""Add a class to the factory
Expand Down
94 changes: 44 additions & 50 deletions src/qp/mixmod_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,61 @@ class mixmod_gen(Pdf_rows_gen):

Notes
-----
This implements a PDF using a Gaussian Mixture model
This is a base class for implementing PDFs using a mixture model.
Classes implementing mixture models with specific basis functions
need to define the pdf and cdf.

The relevant data members are:

means: (npdf, ncomp) means of the Gaussians
stds: (npdf, ncomp) standard deviations of the Gaussians
weights: (npdf, ncomp) weights for the Gaussians
means: (npdf, ncomp) means of the basis functions
stds: (npdf, ncomp) standard deviations of the basis functions
weights: (npdf, ncomp) weights for the basis functions

The pdf() and cdf() are exact, and are computed as a weighted sum of
the pdf() and cdf() of the component Gaussians.
the pdf() and cdf() of the component basis functions.

The ppf() is computed by computing the cdf() values on a fixed
grid and interpolating the inverse function.
"""

# pylint: disable=protected-access

name = "mixmod"
version = 0

_support_mask = rv_continuous._support_mask

def __init__(self, means, stds, weights, *args, **kwargs):
name = 'mixmod'
version = 0

def __init__(self, gen_func, weights, data, ancil=None, *args, **kwargs):
"""
Create a new distribution using the given histogram

Parameters
----------
means : array_like
The means of the Gaussians
The means of the basis functions
stds: array_like
The standard deviations of the Gaussians
The standard deviations of the basis functions
weights : array_like
The weights to attach to the Gaussians. Weights should sum up to one.
If not, the weights are interpreted as relative weights.
"""
self._gen_func = gen_func
self._frozen = self._gen_func(**data)
self._gen_obj = self._frozen.dist
self._gen_class = type(self._gen_obj)
self._data = data

self._scipy_version_warning()
self._means = reshape_to_pdf_size(means, -1)
self._stds = reshape_to_pdf_size(stds, -1)
self._weights = reshape_to_pdf_size(weights, -1)
kwargs["shape"] = means.shape[:-1]
self._ncomps = means.shape[-1]
for key in self._data.keys():
self._data[key] = reshape_to_pdf_size(self._data[key],-1)
kwargs['shape'] = weights.shape[:-1]
self._ncomps = weights.shape[-1]
super().__init__(*args, **kwargs)
if np.any(self._weights < 0):
raise ValueError("All weights need to be larger than zero")
self._weights = self._weights / self._weights.sum(axis=1)[:, None]
self._addobjdata("weights", self._weights)
self._addobjdata("stds", self._stds)
self._addobjdata("means", self._means)
if np.any(self._weights<0):
raise ValueError('All weights need to be larger than zero')
self._weights = self._weights/self._weights.sum(axis=1)[:,None]
self._addobjdata('weights', self._weights)

def _scipy_version_warning(self):
import scipy # pylint: disable=import-outside-toplevel
Expand All @@ -80,42 +86,28 @@ def _scipy_version_warning(self):

@property
def weights(self):
"""Return weights to attach to the Gaussians"""
"""Return weights to attach to the basis functions"""
return self._weights

@property
def means(self):
"""Return means of the Gaussians"""
return self._means

@property
def stds(self):
"""Return standard deviations of the Gaussians"""
return self._stds

def _pdf(self, x, row):
# pylint: disable=arguments-differ
if np.ndim(x) > 1: # pragma: no cover
x = np.expand_dims(x, -2)
return (
self.weights[row].swapaxes(-2, -1)
* sps.norm(
loc=self._means[row].swapaxes(-2, -1),
scale=self._stds[row].swapaxes(-2, -1),
).pdf(x)
).sum(axis=0)
data_swap=dict()
for key in self._data.keys():
data_swap[key] = self._data[key][row].swapaxes(-2,-1)
return (self.weights[row].swapaxes(-2,-1) *
self._gen_func(**data_swap).pdf(x)).sum(axis=0)

def _cdf(self, x, row):
# pylint: disable=arguments-differ
if np.ndim(x) > 1: # pragma: no cover
x = np.expand_dims(x, -2)
return (
self.weights[row].swapaxes(-2, -1)
* sps.norm(
loc=self._means[row].swapaxes(-2, -1),
scale=self._stds[row].swapaxes(-2, -1),
).cdf(x)
).sum(axis=0)
data_swap=dict()
for key in self._data.keys():
data_swap[key] = self._data[key][row].swapaxes(-2,-1)
return (self.weights[row].swapaxes(-2,-1) *
self._gen_func(**data_swap).cdf(x)).sum(axis=0)

def _ppf(self, x, row):
# pylint: disable=arguments-differ
Expand All @@ -140,9 +132,11 @@ def _updated_ctor_param(self):
Set the bins as additional constructor argument
"""
dct = super()._updated_ctor_param()
dct["means"] = self._means
dct["stds"] = self._stds
dct["weights"] = self._weights
# for key in self._data.keys():
# dct[key] = self._data[key]
dct['weights'] = self._weights
dct['data'] = self._data
dct['gen_func'] = self._gen_func
return dct

@classmethod
Expand Down Expand Up @@ -191,7 +185,7 @@ def make_test_data(cls):
)
)


mixmod = mixmod_gen.create

add_class(mixmod_gen)

5 changes: 4 additions & 1 deletion src/qp/pdf_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ def __init__(self, *args, **kwargs):
"""C'tor"""
# pylint: disable=no-member,protected-access
super().__init__(*args, **kwargs)
self._other_init(*args, **kwargs)
if kwargs==self._ctor_param:
kwargs=dict()
kwargs.pop('name', None)
self._other_init(*args, **kwargs, **self._ctor_param)

def _my_freeze(self, *args, **kwds):
"""Freeze the distribution for the given arguments.
Expand Down
3 changes: 2 additions & 1 deletion tests/qp/test_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ def auto_add(cls, class_list, ens_orig):
ENS_MULTI = test_funcs.build_ensemble(
qp.stats.norm_gen.test_data["norm"] # pylint: disable=no-member
)
TEST_CLASSES = qp.instance().values()
TEST_CLASSES = list(qp.instance().values())
TEST_CLASSES.remove(qp.mixmod_pdf.mixmod_gen)

PDFTestCase.auto_add(TEST_CLASSES, [ENS_ORIG, ENS_MULTI])

Expand Down
9 changes: 8 additions & 1 deletion tests/qp/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,13 @@ def test_mixmod_with_negative_weights(self):
with self.assertRaises(ValueError):
_ = qp.mixmod(weights=weights, means=means, stds=sigmas)

def test_mixmod_gen_func_with_negative_weights(self):
    """Verify that a ValueError is raised for negative weights with ``gen_func``.

    Builds a three-component mixture model through the ``gen_func`` / ``data``
    constructor arguments, with ``qp.stats.norm`` as the basis function, and
    passes one negative weight.
    """
    # This test needs its own name: defining a second method called
    # ``test_mixmod_with_negative_weights`` in the same class would rebind
    # that name, and the earlier test would never be collected or run.
    means = np.array([0.5, 1.1, 2.9])
    sigmas = np.array([0.15, 0.13, 0.14])
    weights = np.array([1, 0.5, -0.25])  # last component is deliberately negative
    with self.assertRaises(ValueError):
        _ = qp.mixmod(
            gen_func=qp.stats.norm,
            weights=weights,
            data=dict(loc=means, scale=sigmas),
        )

if __name__ == "__main__":
if __name__ == '__main__':
unittest.main()
Loading