
added generalized gamma distribution #59

Open
kylejcaron wants to merge 8 commits into main

Conversation


@kylejcaron commented Jul 28, 2022

What is this PR about?
This PR implements the generalized gamma distribution, using the parameterization GG(alpha, p, lambda).

This corresponds with the following PR in aesara.

This work originally lived in a PR against pymc, but @ricardoV94 recommended starting in pymc-experimental, which makes sense to me, especially since I'm having trouble using the distribution with the pm.Censored API.
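For reference, the GG(alpha, p, lambda) parameterization corresponds to scipy's `gengamma` via `a = alpha / p`, `c = p`, `scale = lambda` (the same mapping used in the simulation code further down). A quick sketch checking the density formula against scipy; the helper name is mine, not part of the PR:

```python
import numpy as np
from scipy import stats
from scipy.special import gammaln

def gg_logpdf(x, alpha, p, lambd):
    """Log-density of GG(alpha, p, lambda):
    f(x) = p * x**(alpha - 1) * exp(-(x / lambda)**p) / (lambda**alpha * Gamma(alpha / p))
    """
    return (np.log(p) + (alpha - 1) * np.log(x) - (x / lambd) ** p
            - alpha * np.log(lambd) - gammaln(alpha / p))

alpha, p, lambd = 2.0, 2.0, 60.0
x = np.linspace(1.0, 200.0, 50)
# scipy's gengamma with a = alpha/p, c = p, scale = lambda matches the GG density
ref = stats.gengamma(alpha / p, p, scale=lambd).logpdf(x)
assert np.allclose(gg_logpdf(x, alpha, p, lambd), ref)
```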

-------
TensorVariable
"""
logcdf_expression = at.log(at.gammainc(alpha / p, (value / lambd) ** p))
@kylejcaron (Author), Jul 28, 2022

I'm not very familiar with the math, but I'm concerned the logcdf function may have some numerical instability. I attempted to fit a model to simulated data with the pm.Censored API; it was a bit slow and had a few divergences, but it did correctly identify the true parameters.

(I tested this with the latest aeppl fix for the pm.Censored API; the code is in the last section of this notebook.)

Member

I don't see an obvious reparametrization, unless there is some loggammaincc function we could use?

If you are concerned, I would suggest doing a grid evaluation of the logcdf and comparing against arbitrary-precision software like Mathematica.

That will confirm or dispel your concerns about stability, and then we can try to look for an alternative. It could also be that the instability is not in the function itself but in its gradients.

Do you have a minimal example showing the sampling issues that you can share?
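As a rough illustration of the concern, here is a sketch of such a grid check using `scipy.special.gammainc` as a stand-in for `at.gammainc` (both compute the regularized lower incomplete gamma). It shows the naive `log(gammainc(...))` form saturating at `-inf` in the deep lower tail; the parameter values are arbitrary choices of mine:

```python
import numpy as np
from scipy.special import gammainc

def gg_logcdf_naive(value, alpha, p, lambd):
    # Same form as the PR's logcdf_expression, with scipy standing in for aesara
    return np.log(gammainc(alpha / p, (value / lambd) ** p))

alpha, p, lambd = 9.0, 2.0, 1.0
# Moderate values: finite and well-behaved.
print(gg_logcdf_naive(0.5, alpha, p, lambd))
# Deep lower tail: gammainc underflows to exactly 0, so the log saturates
# at -inf even though the true log-CDF is finite (on the order of -800 here).
print(gg_logcdf_naive(1e-40, alpha, p, lambd))  # -inf
```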

@kylejcaron (Author), Aug 5, 2022

Is Mathematica paid software? I had trouble finding anything to test against that would guarantee numerical precision; I'm pretty new to this.

Here's a (more or less) minimal example with sampling issues: sampling is very slow with pm.Censored, but the distribution is quick without it.

import pandas as pd
import scipy.stats as stats
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import pymc_experimental as pmx


SEED = 99


# ########################
# code
# ########################

def sim_data(true_paramsA, true_paramsB, SEED):
    np.random.seed(SEED)

    # 2 groups that generate data
    a = stats.gengamma(true_paramsA[0] / true_paramsA[1], true_paramsA[1], scale=true_paramsA[2])
    b = stats.gengamma(true_paramsB[0] / true_paramsB[1], true_paramsB[1], scale=true_paramsB[2])
    N = 100  # obs per group

    # create a dataset of N obs from each group
    y_true = np.r_[a.rvs(N), b.rvs(N)]

    # randomly censor the dataset for survival analysis
    cens_time = np.random.uniform(0, y_true.max(), size=N * 2)

    data = (
        pd.DataFrame({"group": [0] * N + [1] * N, "time": y_true})
        # adjust the dataset to censor observations
        ## cens=1 indicates an event hasn't occurred yet
        .assign(cens=lambda d: np.where(d.time <= cens_time, 0, 1))
        ## time becomes the latest time observed for each record
        .assign(time=lambda d: np.where(d.cens == 1, cens_time, d.time))
    )

    return data


def fit_model(data):
    cens_ = np.where(data.cens == 1, data.time, np.inf)
    COORDS = {"group": ["A", "B"]}
    group_ = data.group.values
    y_ = data.time.values

    with pm.Model(coords=COORDS) as model:

        # weakly informative priors for each parameter
        log_alpha = pm.Normal("log_alpha", 2, 0.75, dims="group")
        log_p = pm.Normal("log_p", 0.45, 0.3, dims="group")
        log_lambd = pm.Normal("log_lambd", 4.1, 0.4, dims="group")

        # helper vars
        alpha = pm.Deterministic("alpha", pm.math.exp(log_alpha), dims="group")
        p = pm.Deterministic("p", pm.math.exp(log_p), dims="group")
        lambd = pm.Deterministic("lambd", pm.math.exp(log_lambd), dims="group")

        # latent variable needed for censored fitting
        y_latent = pmx.GeneralizedGamma.dist(alpha[group_], p[group_], lambd[group_])

        # likelihood of the observed data
        obs = pm.Censored("obs", y_latent, lower=None, upper=cens_, observed=y_)

        # fit
        idata = pm.sample(init="adapt_diag", random_seed=SEED)
        idata.extend(pm.sample_prior_predictive())
        idata.extend(pm.sample_posterior_predictive(idata))

    return model, idata


def main():
    true_paramsA = [2, 2, 60]
    true_paramsB = [20, 2, 60]
    data = sim_data(true_paramsA, true_paramsB, SEED)
    model, idata = fit_model(data)


if __name__ == "__main__":
    main()

Member

Are you using the main version of Aeppl? We haven't done a new release yet, but there was a recent patch that makes the gradients of censored logps more stable: aesara-devs/aeppl#156

Member

> Is Mathematica a paid software?

I might have jumbled the names. In the past I've used https://www.wolframcloud.com/ with a free account.
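For what it's worth, reference values in the lower tail can also be obtained without arbitrary-precision software: the standard Kummer series for the regularized lower incomplete gamma P(a, x) can be evaluated directly in log space. A sketch (the series form and term count are my own choices, not anything in the PR):

```python
import numpy as np
from scipy.special import gammaln, gammainc

def log_gammainc_series(a, x, n_terms=200):
    """log P(a, x) via the Kummer series:
    P(a, x) = x**a * exp(-x) / Gamma(a) * sum_k x**k / (a * (a+1) * ... * (a+k)),
    evaluated in log space so it stays finite deep in the lower tail.
    """
    # accumulate sum_{k>=0} x**k / (a * (a+1) * ... * (a+k))
    term, total = 1.0 / a, 1.0 / a
    for k in range(1, n_terms):
        term *= x / (a + k)
        total += term
    return a * np.log(x) - x - gammaln(a) + np.log(total)

# Agrees with the naive log(gammainc) where the latter is representable...
a, x = 4.5, 0.25
assert np.isclose(log_gammainc_series(a, x), np.log(gammainc(a, x)))
# ...and stays finite where the naive form underflows to -inf.
print(log_gammainc_series(4.5, 1e-80))  # about -833
```

This only converges quickly for x not much larger than a, so it is a tail reference, not a general replacement.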

@kylejcaron (Author), Aug 8, 2022

One thing to note: everything is fine and samples quickly without censored data and without the pm.Censored API involved (i.e. regular fitting on non-censored data).

Also, sampling might actually have gotten slower with less censoring (when using the pm.Censored API)? I didn't sample to completion; I just looked at the initial estimate after one minute of sampling:

  • Original, 46% (96/200 obs censored) - 40 minutes of sampling estimated
  • 25% censoring (50/200 censored) - 50 minutes
  • 2.5% (5/200 censored) - 1 hour

For testing other distributions:

  • Censored normal data sampled quickly for me (<5 seconds)
  • Censored Gamma was really slow (estimated at > 1 hour; it also got slower with less censoring). Uncensored fitting on the same data was fine.

The fact that fitting gets slower with less censoring seems really off to me, and I'm also surprised to see that the Gamma distribution suffers from similar issues.
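For context on where the logcdf enters: with right censoring, each censored record contributes a log-survival term log(1 - CDF) while observed events contribute the log-density, so every gradient evaluation has to differentiate through the incomplete gamma function. A plain scipy sketch of the objective that pm.Censored effectively builds (the helper and names are mine, illustrative only):

```python
import numpy as np
from scipy import stats

def censored_gg_loglik(time, cens, alpha, p, lambd):
    """Right-censored generalized-gamma log-likelihood:
    observed events contribute logpdf, censored records logsf = log(1 - CDF)."""
    dist = stats.gengamma(alpha / p, p, scale=lambd)
    return np.sum(np.where(cens == 1, dist.logsf(time), dist.logpdf(time)))

rng = np.random.default_rng(99)
alpha, p, lambd = 2.0, 2.0, 60.0
y = stats.gengamma(alpha / p, p, scale=lambd).rvs(100, random_state=rng)
cens_time = rng.uniform(0, y.max(), size=100)
cens = (y > cens_time).astype(int)           # 1 = censored before the event
time = np.where(cens == 1, cens_time, y)     # observe the censor time instead
print(censored_gg_loglik(time, cens, alpha, p, lambd))
```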

@kylejcaron (Author)

Update: I just looked at the logcdf of the Gamma distribution, and it's implemented very similarly to how I implemented the generalized gamma, so it makes sense that it's slow. I guess this points to the gammainc_der op being the issue more than the parameterization.
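For intuition on why gammainc_der would be the bottleneck: the gradient of gammainc(a, x) with respect to x has a cheap closed form, d/dx P(a, x) = x^(a-1) e^(-x) / Gamma(a), while the derivative with respect to a (what gammainc_der has to compute) has no closed form and needs a series or numerical evaluation at every step. A quick finite-difference check of the easy part, with scipy standing in for the aesara op:

```python
import numpy as np
from scipy.special import gammainc, gammaln

def dP_dx(a, x):
    """Closed-form derivative of the regularized lower incomplete gamma
    P(a, x) with respect to x: x**(a - 1) * exp(-x) / Gamma(a)."""
    return np.exp((a - 1) * np.log(x) - x - gammaln(a))

a, x, h = 4.5, 2.0, 1e-6
# central finite difference of gammainc in x
fd = (gammainc(a, x + h) - gammainc(a, x - h)) / (2 * h)
assert np.isclose(dP_dx(a, x), fd, rtol=1e-5)
```

No analogous closed form exists for the a-derivative, which is the part that censored Gamma-family logps exercise heavily.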

Member

> update: just looked at the log-cdf of the gamma distribution and its implemented very similar to how I implemented the generalized gamma, so makes sense its slow - I guess this points to the gammainc_der op being an issue more than the parameterization

Yes, that's quite possible.

@kylejcaron (Author), Aug 12, 2022

Given that fitting with the distribution works well without the Censored API, and the regular Gamma distribution has the same problem in PyMC anyway, is it fair to call this OK and proceed with the PR (once tests are passing), since it's experimental anyway?

It'd at least be a good starting point to improve upon. The parameterization I chose is fairly interpretable, and I already have some use cases for it (and if it ever plays nicely with pm.Censored, I'll have a ton of them).

Member

Definitely! I'll try to review it again soon.

@ricardoV94 (Member) left a review:

This looks great!

There are some type-hint issues, but otherwise everything looks correct.

I left some suggestions to investigate the logcdf concerns but I am afraid I was not of much help.

(Four review threads on pymc_experimental/distributions/continuous.py: resolved.)
(Review thread on pymc_experimental/tests/test_distributions.py: resolved.)
@ricardoV94 (Member)

@OriolAbril Can you give some guidance on how to include the distribution in the doc pages?

@OriolAbril (Member)

It needs to be added to https://github.com/pymc-devs/pymc-experimental/blob/main/docs/api_reference.rst, which I think also needs some extra changes (along with maybe some changes to the imports and module hierarchy). For this PR only:

:mod:`pymc_experimental.distributions`
======================================

.. automodule:: pymc_experimental.distributions
   :members: histogram_approximation, GeneralizedGamma

In general, is the plan for users to use pymc_experimental.utils.prior.prior_from_idata or pymc_experimental.prior_from_idata? Depending on the plan, the imports and structure used when generating the API docs should be updated. The docs currently tell users to import prior_from_idata from pmx.utils.prior.

3 participants