Skip to content

Commit

Permalink
fixed mixture model implementation with scipy basis functions that re…
Browse files Browse the repository at this point in the history
…quire additional parameters
  • Loading branch information
Benjamin Stölzner authored and eacharles committed Nov 29, 2023
1 parent 6a2952e commit 6b51f11
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/qp/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _build_data_dict(md_table, data_table):
data_dict[col] = col_data
return data_dict

def _make_scipy_wrapped_class(self, class_name, scipy_class):
def _make_scipy_wrapped_class(self, class_name, scipy_class, ctor_param):
"""Build a qp class from a scipy class"""
# pylint: disable=protected-access
override_dict = dict(
Expand All @@ -72,7 +72,7 @@ def _load_scipy_classes(self):
for name in names:
attr = getattr(sps, name)
if isinstance(attr, sps.rv_continuous):
self._make_scipy_wrapped_class(name, type(attr))
self._make_scipy_wrapped_class(name, type(attr), attr._updated_ctor_param())

def add_class(self, the_class):
"""Add a class to the factory
Expand Down
2 changes: 0 additions & 2 deletions src/qp/mixmod_pdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ def __init__(self, gen_func, weights, data, ancil=None, *args, **kwargs):
>>>>>>> 9b08a50 (Started implementing mixture model implementation with generic scipy base function. So far only the ones using loc and scale parameters work. PPFs need to be implemented.)
"""
self._gen_func = gen_func
print(data)
self._frozen = self._gen_func(**data)
self._gen_obj = self._frozen.dist
self._gen_class = type(self._gen_obj)
Expand All @@ -87,7 +86,6 @@ def __init__(self, gen_func, weights, data, ancil=None, *args, **kwargs):
=======
for key in self._data.keys():
self._data[key] = reshape_to_pdf_size(self._data[key],-1)
print(self._data)
kwargs['shape'] = weights.shape[:-1]
self._ncomps = weights.shape[-1]
super().__init__(*args, **kwargs)
Expand Down
9 changes: 8 additions & 1 deletion src/qp/pdf_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,13 @@ def create_gen(cls, **kwds):
"""Create and return a `scipy.stats.rv_continuous` object using the
keyword arguemntets provided"""
kwds_copy = kwds.copy()
<<<<<<< HEAD
name = kwds_copy.pop("name", "dist")
return (cls(name=name), kwds_copy)
=======
name = kwds_copy.pop('name', 'dist')
return (cls(), kwds_copy)
>>>>>>> 71047e2 (fixed mixture model implementation with scipy basis functions that require additional parameters)

@classmethod
def create(cls, **kwds):
Expand Down Expand Up @@ -393,7 +398,9 @@ def __init__(self, *args, **kwargs):
"""C'tor"""
# pylint: disable=no-member,protected-access
super().__init__(*args, **kwargs)
self._other_init(*args, **kwargs)
if kwargs==self._ctor_param:
kwargs=dict()
self._other_init(*args, **kwargs, **self._ctor_param)

def _my_freeze(self, *args, **kwds):
"""Freeze the distribution for the given arguments.
Expand Down

0 comments on commit 6b51f11

Please sign in to comment.