added generalized gamma distribution #59

Open · wants to merge 8 commits into base: main

Changes from 6 commits
2 changes: 1 addition & 1 deletion docs/api_reference.rst
@@ -20,7 +20,7 @@ methods in the current release of PyMC experimental.
=============================

.. automodule:: pymc_experimental.distributions.histogram_utils
:members: histogram_approximation
:members: histogram_approximation, GeneralizedGamma


:mod:`pymc_experimental.utils`
3 changes: 3 additions & 0 deletions pymc_experimental/__init__.py
@@ -13,3 +13,6 @@

from pymc_experimental import distributions, gp, utils
from pymc_experimental.bart import *
from pymc_experimental.distributions import (
    GeneralizedGamma
)
3 changes: 3 additions & 0 deletions pymc_experimental/distributions/__init__.py
@@ -1,2 +1,5 @@
from pymc_experimental.distributions import histogram_utils
from pymc_experimental.distributions.histogram_utils import histogram_approximation
from pymc_experimental.distributions.continuous import (
    GeneralizedGamma,
)
175 changes: 175 additions & 0 deletions pymc_experimental/distributions/continuous.py
@@ -0,0 +1,175 @@
import numpy as np
import aesara.tensor as at
from pymc.aesaraf import floatX
from aesara.tensor.var import TensorVariable

from aesara.tensor.random.basic import gengamma
from pymc.distributions.continuous import PositiveContinuous
from pymc.distributions.shape_utils import rv_size_is_none
from pymc.distributions.dist_math import check_parameters


class GeneralizedGamma(PositiveContinuous):
    r"""
    Generalized Gamma log-likelihood.

    The pdf of this distribution is

    .. math::

        f(x \mid \alpha, p, \lambda) =
            \frac{p \lambda^{-1} (x/\lambda)^{\alpha - 1} e^{-(x/\lambda)^p}}
            {\Gamma(\alpha/p)}

    .. plot::
        :context: close-figs

        import matplotlib.pyplot as plt
        import numpy as np
        import scipy.stats as st
        import arviz as az
        plt.style.use('arviz-darkgrid')
        x = np.linspace(1, 50, 1000)
        alphas = [1, 1, 2, 2]
        ps = [1, 2, 4, 4]
        lambds = [10., 10., 10., 20.]
        for alpha, p, lambd in zip(alphas, ps, lambds):
            pdf = st.gengamma.pdf(x, alpha / p, p, scale=lambd)
            plt.plot(x, pdf, label=r'$\alpha$ = {}, $p$ = {}, $\lambda$ = {}'.format(alpha, p, lambd))
        plt.xlabel('x', fontsize=12)
        plt.ylabel('f(x)', fontsize=12)
        plt.legend(loc=1)
        plt.show()

    ======== ==========================================
    Support  :math:`x \in [0, \infty)`
    Mean     :math:`\lambda \frac{\Gamma((\alpha+1)/p)}{\Gamma(\alpha/p)}`
    Variance :math:`\lambda^2 \left( \frac{\Gamma((\alpha+2)/p)}{\Gamma(\alpha/p)} - \left(\frac{\Gamma((\alpha+1)/p)}{\Gamma(\alpha/p)}\right)^2 \right)`
    ======== ==========================================

    Parameters
    ----------
    alpha : tensor_like of float, optional
        Shape parameter :math:`\alpha` (``alpha`` > 0).
        Defaults to 1.
    p : tensor_like of float, optional
        Additional shape parameter :math:`p` (``p`` > 0).
        Defaults to 1.
    lambd : tensor_like of float, optional
        Scale parameter :math:`\lambda` (``lambd`` > 0).
        Defaults to 1.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        import pymc_experimental as pmx

        with pm.Model():
            x = pmx.GeneralizedGamma('x', alpha=1, p=2, lambd=5)
    """
    rv_op = gengamma

    @classmethod
    def dist(cls, alpha, p, lambd, **kwargs):
        alpha = at.as_tensor_variable(floatX(alpha))
        p = at.as_tensor_variable(floatX(p))
        lambd = at.as_tensor_variable(floatX(lambd))

        return super().dist([alpha, p, lambd], **kwargs)

    def moment(rv, size, alpha, p, lambd):
        alpha, p, lambd = at.broadcast_arrays(alpha, p, lambd)
        moment = lambd * at.gamma((alpha + 1) / p) / at.gamma(alpha / p)
        if not rv_size_is_none(size):
            moment = at.full(size, moment)
        return moment

    def logp(
        value,
        alpha: TensorVariable,
        p: TensorVariable,
        lambd: TensorVariable,
    ) -> TensorVariable:
        """
        Calculate the log-probability of the Generalized Gamma distribution at a specified value.

        Parameters
        ----------
        value : tensor_like of float
            Value(s) for which the log-probability is calculated. If log-probabilities for
            multiple values are desired, the values must be provided in a numpy array or
            Aesara tensor.
        alpha : tensor_like of float
            Shape parameter (alpha > 0).
        p : tensor_like of float
            Shape parameter (p > 0).
        lambd : tensor_like of float
            Scale parameter (lambd > 0).

        Returns
        -------
        TensorVariable
        """
        logp_expression = (
            at.log(p)
            - at.log(lambd)
            + (alpha - 1) * at.log(value / lambd)
            - (value / lambd) ** p
            - at.gammaln(alpha / p)
        )

        bounded_logp_expression = at.switch(
            at.gt(value, 0),
            logp_expression,
            -np.inf,
        )

        return check_parameters(
            bounded_logp_expression,
            alpha > 0,
            p > 0,
            lambd > 0,
            msg="alpha > 0, p > 0, lambd > 0",
        )

    def logcdf(
        value,
        alpha: TensorVariable,
        p: TensorVariable,
        lambd: TensorVariable,
    ) -> TensorVariable:
        """
        Compute the log of the cumulative distribution function of the Generalized Gamma
        distribution at a specified value.

        Parameters
        ----------
        value : tensor_like of float
            Value(s) for which the log CDF is calculated. If the log CDF for multiple
            values is desired, the values must be provided in a numpy array or Aesara tensor.
        alpha : tensor_like of float
            Shape parameter (alpha > 0).
        p : tensor_like of float
            Shape parameter (p > 0).
        lambd : tensor_like of float
            Scale parameter (lambd > 0).

        Returns
        -------
        TensorVariable
        """
        logcdf_expression = at.log(at.gammainc(alpha / p, (value / lambd) ** p))
@kylejcaron (Author) · Jul 28, 2022:
I'm not very familiar with the math, but I am concerned the logcdf function may have some numerical instability. I attempted to fit a model to simulated data with the pm.Censored API; it was a bit slow and had a few divergences, but it did correctly identify the true parameters.

(I tested this with the latest aeppl fix for the pm.Censored API; the code is in the last section of this notebook.)

Member:
I don't see an obvious reparametrization, unless there is some loggammaincc function we could use?

If you are concerned, I would suggest you do a grid evaluation of the logcdf and compare it against arbitrary-precision software like Mathematica.

That will confirm or dispel your concerns about stability, and then we can try to look for an alternative. It could also be that the instability is not in the function itself but in its gradients.

Do you have a minimal example of the sampling issues that you can share?
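For the grid evaluation, a minimal sketch could use mpmath as the arbitrary-precision reference (mpmath is an assumption here, standing in for Mathematica; the grid and parameter values are arbitrary, and the same grid could also be fed to the compiled Aesara logcdf):

import numpy as np
import mpmath
import scipy.stats as st

mpmath.mp.dps = 50  # 50 decimal digits of working precision

def logcdf_mp(x, alpha, p, lambd):
    # regularized lower incomplete gamma P(alpha/p, (x/lambd)**p),
    # evaluated in high precision before taking the log
    P = mpmath.gammainc(alpha / p, 0, (x / lambd) ** p, regularized=True)
    return float(mpmath.log(P))

alpha, p, lambd = 2.0, 3.0, 5.0
for x in np.logspace(-2, 2, 9):
    ref = logcdf_mp(x, alpha, p, lambd)
    approx = st.gengamma.logcdf(x, a=alpha / p, c=p, scale=lambd)
    print(f"x={x:10.4f}  mpmath={ref:+.12f}  scipy={approx:+.12f}")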

@kylejcaron (Author) · Aug 5, 2022:
Is Mathematica paid software? I had trouble finding anything to test against that would guarantee numerical precision; I'm pretty new to this.

Here's a minimal (kind of) example with the sampling issues (very slow sampling; the distribution is quick when not using the pm.Censored API):

import pandas as pd
import scipy.stats as stats
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pymc_experimental as pmx


SEED = 99


# ########################
# code
# ########################

def sim_data(true_paramsA, true_paramsB, SEED):
    np.random.seed(SEED)

    # 2 groups that generate data
    a = stats.gengamma(true_paramsA[0] / true_paramsA[1], true_paramsA[1], scale=true_paramsA[2])
    b = stats.gengamma(true_paramsB[0] / true_paramsB[1], true_paramsB[1], scale=true_paramsB[2])
    N = 100  # obs per group

    # create a dataset of N obs from each group
    y_true = np.r_[
        a.rvs(N),
        b.rvs(N)
    ]

    # randomly censor the dataset for survival analysis
    cens_time = np.random.uniform(0, y_true.max(), size=N * 2)

    data = (
        pd.DataFrame({
            "group": [0] * N + [1] * N,
            "time": y_true})
        # adjust the dataset to censor observations
        ## cens=1 indicates an event hasn't occurred yet
        .assign(cens=lambda d: np.where(d.time <= cens_time, 0, 1))
        ## time becomes the latest time observed for each record
        .assign(time=lambda d: np.where(d.cens == 1, cens_time, d.time))
    )

    return data

def fit_model(data):

    cens_ = np.where(data.cens == 1, data.time, np.inf)
    COORDS = {"group": ["A", "B"]}
    group_ = data.group.values
    y_ = data.time.values

    with pm.Model(coords=COORDS) as model:

        # weakly informative priors for each parameter
        log_alpha = pm.Normal("log_alpha", 2, 0.75, dims="group")
        log_p = pm.Normal("log_p", 0.45, 0.3, dims="group")
        log_lambd = pm.Normal("log_lambd", 4.1, 0.4, dims="group")

        # helper vars
        alpha = pm.Deterministic("alpha", pm.math.exp(log_alpha), dims="group")
        p = pm.Deterministic("p", pm.math.exp(log_p), dims="group")
        lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")

        # latent variable needed for censored fitting
        y_latent = pmx.GeneralizedGamma.dist(alpha[group_], p[group_], lambd[group_])

        # likelihood of the observed data
        obs = pm.Censored("obs", y_latent, lower=None, upper=cens_, observed=y_)

        # fit
        idata = pm.sample(init="adapt_diag", random_seed=SEED)
        idata.extend(pm.sample_prior_predictive())
        idata.extend(pm.sample_posterior_predictive(idata))

    return model, idata


def main():
    true_paramsA = [2, 2, 60]
    true_paramsB = [20, 2, 60]
    data = sim_data(true_paramsA, true_paramsB, SEED)
    model, idata = fit_model(data)


if __name__ == '__main__':
    main()
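A quick way to check that the true parameters are recovered, assuming the model and idata above, is an ArviZ summary (this snippet is illustrative):

# posterior stats for comparison against the simulated truth:
# true_paramsA = [2, 2, 60], true_paramsB = [20, 2, 60]
print(az.summary(idata, var_names=["alpha", "p", "lambd"], kind="stats"))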

Member:
Are you using the main branch of Aeppl? We haven't done a new release yet, but there was a recent patch that makes the gradients of censored logps more stable: aesara-devs/aeppl#156

Member:
> Is Mathematica paid software?

I might have jumbled the names. In the past I've used https://www.wolframcloud.com/ with a free account.

@kylejcaron (Author) · Aug 8, 2022:
One thing to note is that everything is fine and samples quickly without censored data and without the Censored API involved (i.e. regular fitting on non-censored data).

Also, sampling might actually have gotten slower with less censoring (when using the Censored API)? I didn't sample to completion; I just looked at the initial estimate after 1 minute of sampling:

  • Original, 46% censoring (96/200 obs censored) - ~40 minutes of sampling estimated
  • 25% censoring (50/200 censored) - ~50 minutes
  • 2.5% censoring (5/200 censored) - ~1 hour

For comparison, testing other distributions:

  • Censored normal data sampled quickly for me (<5 seconds)
  • Censored Gamma was really slow (estimated > 1 hour, and it also got slower with less censoring). Uncensored fitting on the same data was fine.

The fact that fitting gets slower with less censoring seems really off to me, and I'm also surprised that the Gamma distribution suffers from similar issues.

@kylejcaron (Author):
Update: I just looked at the log-CDF of the Gamma distribution, and it's implemented very similarly to how I implemented the generalized gamma, so it makes sense that it's slow. I guess this points to the gammainc_der Op being the issue more than the parameterization.
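One rough way to test that hypothesis (an illustrative sketch, not from the PR; the timing loop and parameter values are arbitrary) is to time the compiled logcdf expression against its gradient, since the gradient graph is what pulls in gammainc's derivative Op:

import time
import aesara
import aesara.tensor as at

value = at.dscalar("value")
alpha, p, lambd = at.dscalars("alpha", "p", "lambd")

# same logcdf expression as in the PR
logcdf = at.log(at.gammainc(alpha / p, (value / lambd) ** p))

f_val = aesara.function([value, alpha, p, lambd], logcdf)
f_grad = aesara.function([value, alpha, p, lambd], aesara.grad(logcdf, [alpha, p, lambd]))

args = (30.0, 2.0, 2.0, 60.0)
t0 = time.perf_counter()
for _ in range(1000):
    f_val(*args)
t1 = time.perf_counter()
for _ in range(1000):
    f_grad(*args)
t2 = time.perf_counter()
print(f"1000 logcdf calls: {t1 - t0:.3f}s; 1000 gradient calls: {t2 - t1:.3f}s")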

Member:
> I guess this points to the gammainc_der Op being the issue more than the parameterization.

Yes, that's quite possible.

@kylejcaron (Author) · Aug 12, 2022:
Given that fitting with the distribution works well without the Censored API, and the regular Gamma distribution has the same problem in PyMC anyway, is it fair to call this OK and proceed with the PR (once tests are passing), since this is experimental anyway?

It would at least be a good starting point to improve upon. The current parameterization I chose is fairly interpretable, and I still have some use cases for it (and if it ever plays nicely with pm.Censored I'll have a ton of use cases).

Member:
Definitely! I'll try to review it again soon.



        bounded_logcdf_expression = at.switch(
            at.gt(value, 0),
            logcdf_expression,
            -np.inf,
        )

        return check_parameters(
            bounded_logcdf_expression,
            alpha > 0,
            p > 0,
            lambd > 0,
            msg="alpha > 0, p > 0, lambd > 0",
        )

21 changes: 21 additions & 0 deletions pymc_experimental/tests/test_distributions.py
@@ -0,0 +1,21 @@
import pytest
import scipy.stats
import scipy.stats.distributions as sp
from pymc_experimental.distributions import GeneralizedGamma
from pymc.tests.test_distributions import Rplus, Rplusbig, check_logp, check_logcdf


class TestMatchesScipy:
    def test_generalized_gamma(self):
        check_logp(
            GeneralizedGamma,
            Rplus,
            {"alpha": Rplusbig, "p": Rplusbig, "lambd": Rplusbig},
            lambda value, alpha, p, lambd: sp.gengamma.logpdf(value, a=alpha / p, c=p, scale=lambd),
        )
        check_logcdf(
            GeneralizedGamma,
            Rplus,
            {"alpha": Rplusbig, "p": Rplusbig, "lambd": Rplusbig},
            lambda value, alpha, p, lambd: sp.gengamma.logcdf(value, a=alpha / p, c=p, scale=lambd),
        )
35 changes: 35 additions & 0 deletions pymc_experimental/tests/test_distributions_moments.py
@@ -0,0 +1,35 @@
import pytest

import numpy as np
import scipy.special as special
import pymc as pm
from pymc import Model
from pymc.tests.test_distributions_moments import assert_moment_is_expected
from pymc_experimental.distributions import GeneralizedGamma


@pytest.mark.parametrize(
    "alpha, p, lambd, size, expected",
    [
        (1, 1, 2, None, 2),
        (1, 1, 2, 5, np.full(5, 2)),
        (1, 1, np.arange(1, 6), None, np.arange(1, 6)),
        (
            np.arange(1, 6),
            2 * np.arange(1, 6),
            10,
            (2, 5),
            np.full(
                (2, 5),
                10
                * special.gamma((np.arange(1, 6) + 1) / (np.arange(1, 6) * 2))
                / special.gamma(np.arange(1, 6) / (np.arange(1, 6) * 2)),
            ),
        ),
    ],
)
def test_generalized_gamma_moment(alpha, p, lambd, size, expected):
    with Model() as model:
        GeneralizedGamma("x", alpha=alpha, p=p, lambd=lambd, size=size)
    assert_moment_is_expected(model, expected)

22 changes: 22 additions & 0 deletions pymc_experimental/tests/test_distributions_random.py
@@ -0,0 +1,22 @@
import pytest

import numpy as np
import scipy.special as special
import pymc_experimental as pmx
from pymc.tests.test_distributions_random import (
    BaseTestDistributionRandom,
    seeded_scipy_distribution_builder,
)


class TestGeneralizedGamma(BaseTestDistributionRandom):
    pymc_dist = pmx.GeneralizedGamma
    pymc_dist_params = {"alpha": 2.0, "p": 3.0, "lambd": 5.0}
    expected_rv_op_params = {"alpha": 2.0, "p": 3.0, "lambd": 5.0}
    reference_dist_params = {"a": 2.0 / 3.0, "c": 3.0, "scale": 5.0}
    reference_dist = seeded_scipy_distribution_builder("gengamma")
    checks_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
    ]