Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Symmetric and Asymmetric Multivariate Laplace distributions #389

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.17.0 # CI was failing to resolve
- pymc>=5.18.0 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,6 @@ dependencies:
- xhistogram
- statsmodels
- pip:
- pymc>=5.17.0 # CI was failing to resolve
- pymc>=5.18.0 # CI was failing to resolve
- blackjax
- scikit-learn
2 changes: 2 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ Distributions
GeneralizedPoisson
BetaNegativeBinomial
GenExtreme
MvAsymmetricLaplace
MvLaplace
R2D2M2CP
Skellam
histogram_approximation
Expand Down
5 changes: 4 additions & 1 deletion pymc_experimental/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@
Skellam,
)
from pymc_experimental.distributions.histogram_utils import histogram_approximation
from pymc_experimental.distributions.multivariate import R2D2M2CP
from pymc_experimental.distributions.multivariate.laplace import MvAsymmetricLaplace, MvLaplace
from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain

__all__ = [
"BetaNegativeBinomial",
"DiscreteMarkovChain",
"GeneralizedPoisson",
"GenExtreme",
"MvAsymmetricLaplace",
"MvLaplace",
"R2D2M2CP",
"Skellam",
"histogram_approximation",
Expand Down
3 changes: 0 additions & 3 deletions pymc_experimental/distributions/multivariate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP

__all__ = ["R2D2M2CP"]
276 changes: 276 additions & 0 deletions pymc_experimental/distributions/multivariate/laplace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
import numpy as np
import pytensor.tensor as pt
import scipy

from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable, _support_point
from pymc.distributions.moments.means import _mean
from pymc.distributions.multivariate import (
_logdet_from_cholesky,
nan_lower_cholesky,
quaddist_chol,
quaddist_matrix,
solve_lower,
)
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.logprob.basic import _logprob
from pymc.pytensorf import normalize_rng_param
from pytensor.gradient import grad_not_implemented
from pytensor.scalar import BinaryScalarOp, upgrade_to_float
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import normalize_size_param


class Kv(BinaryScalarOp):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll move this to PyTensor before merging this PR

"""
Modified Bessel function of the second kind of real order v.
"""

nfunc_spec = ("scipy.special.kv", 2, 1)

@staticmethod
def st_impl(v, x):
return scipy.special.kv(v, x)

def impl(self, v, x):
return self.st_impl(v, x)

def L_op(self, inputs, outputs, output_grads):
v, x = inputs
[out] = outputs
[g_out] = output_grads
dx = -(v / x) * out - self.kv(v - 1, x)
return [grad_not_implemented(self, 0, v), g_out * dx]

def c_code(self, *args, **kwargs):
raise NotImplementedError()


kv = Elemwise(Kv(upgrade_to_float, name="kv"))


class MvLaplaceRV(SymbolicRandomVariable):
name = "multivariate_laplace"
extended_signature = "[rng],[size],(m),(m,m)->[rng],(m)"
_print_name = ("MultivariateLaplace", "\\operatorname{MultivariateLaplace}")

@classmethod
def rv_op(cls, mu, cov, *, size=None, rng=None):
mu = pt.as_tensor(mu)
cov = pt.as_tensor(cov)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

assert mu.type.ndim >= 1
assert cov.type.ndim >= 2

if rv_size_is_none(size):
size = implicit_size_from_params(mu, cov, ndims_params=(1, 2))

next_rng, e = pt.random.exponential(size=size, rng=rng).owner.outputs
next_rng, z = pt.random.multivariate_normal(
mean=pt.zeros(mu.shape[-1]), cov=cov, size=size, rng=next_rng
).owner.outputs
rv = mu + pt.sqrt(e)[..., None] * z

return cls(
inputs=[rng, size, mu, cov],
outputs=[next_rng, rv],
)(rng, size, mu, cov)


class MvLaplace(Continuous):
r"""Multivariate (Symmetric) Laplace distribution.

The pdf of this distribution is

.. math::

pdf(x \mid \mu, \Sigma) =
\frac{2}{(2\pi)^{k/2} |\Sigma|^{1/2}}
( \frac{(x-\mu)'\Sigma^{-1}(x-mu)}{2} )^{v/2}
\K_v (\sqrt{2(x-\mu)' \Sigma^{-1} (x - \mu)}})

where :math:`v = 1 - k/2` and :math:`\K_v` is the modified Bessel function of the second kind.

======== ==========================
Support :math:`x \in \mathbb{R}^k`
Mean :math:`\mu`
Variance :math:`\Sigma`
======== ==========================

Parameters
----------
mu : tensor_like of float
Location.
cov : tensor_like of float, optional
Covariance matrix. Exactly one of cov, tau, or chol is needed.
tau : tensor_like of float, optional
Precision matrix. Exactly one of cov, tau, or chol is needed.
chol : tensor_like of float, optional
Cholesky decomposition of covariance matrix. Exactly one of cov,
tau, or chol is needed.
lower: bool, default=True
Whether chol is the lower tridiagonal cholesky factor.
"""

rv_type = MvLaplaceRV
rv_op = MvLaplaceRV.rv_op

@classmethod
def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
cov = quaddist_matrix(cov, chol, tau, lower)

mu = pt.atleast_1d(pt.as_tensor_variable(mu))
if mu.type.broadcastable[-1] and not cov.type.broadcastable[-1]:
mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
return super().dist([mu, cov], **kwargs)


class MvAsymmetricLaplaceRV(SymbolicRandomVariable):
name = "multivariate_asymmetric_laplace"
extended_signature = "[rng],[size],(m),(m,m)->[rng],(m)"
_print_name = ("MultivariateAsymmetricLaplace", "\\operatorname{MultivariateAsymmetricLaplace}")

@classmethod
def rv_op(cls, mu, cov, *, size=None, rng=None):
mu = pt.as_tensor(mu)
cov = pt.as_tensor(cov)
rng = normalize_rng_param(rng)
size = normalize_size_param(size)

assert mu.type.ndim >= 1
assert cov.type.ndim >= 2

if rv_size_is_none(size):
size = implicit_size_from_params(mu, cov, ndims_params=(1, 2))

next_rng, e = pt.random.exponential(size=size, rng=rng).owner.outputs
next_rng, z = pt.random.multivariate_normal(
mean=pt.zeros(mu.shape[-1]), cov=cov, size=size, rng=next_rng
).owner.outputs
e = e[..., None]
rv = e * mu + pt.sqrt(e) * z

return cls(
inputs=[rng, size, mu, cov],
outputs=[next_rng, rv],
)(rng, size, mu, cov)


class MvAsymmetricLaplace(Continuous):
r"""Multivariate Asymmetric Laplace distribution.

The pdf of this distribution is

.. math::

pdf(x \mid \mu, \Sigma) =
\frac{2}{(2\pi)^{k/2} |\Sigma|^{1/2}}
( \frac{(x-\mu)'\Sigma^{-1}(x-mu)}{2} )^{v/2}
\K_v (\sqrt{2(x-\mu)' \Sigma^{-1} (x - \mu)}})

where :math:`v = 1 - k/2` and :math:`\K_v` is the modified Bessel function of the second kind.

======== ==========================
Support :math:`x \in \mathbb{R}^k`
Mean :math:`\mu`
Variance :math:`\Sigma + \mu' \mu`
======== ==========================

Parameters
----------
mu : tensor_like of float
Location.
cov : tensor_like of float, optional
Covariance matrix. Exactly one of cov, tau, or chol is needed.
tau : tensor_like of float, optional
Precision matrix. Exactly one of cov, tau, or chol is needed.
chol : tensor_like of float, optional
Cholesky decomposition of covariance matrix. Exactly one of cov,
tau, or chol is needed.
lower: bool, default=True
Whether chol is the lower tridiagonal cholesky factor.
"""

rv_type = MvAsymmetricLaplaceRV
rv_op = MvAsymmetricLaplaceRV.rv_op

@classmethod
def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
cov = quaddist_matrix(cov, chol, tau, lower)

mu = pt.atleast_1d(pt.as_tensor_variable(mu))
if mu.type.broadcastable[-1] and not cov.type.broadcastable[-1]:
mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
return super().dist([mu, cov], **kwargs)


@_logprob.register(MvLaplaceRV)
def mv_laplace_logp(op, values, rng, size, mu, cov, **kwargs):
[value] = values
quaddist, logdet, posdef = quaddist_chol(value, mu, cov)

k = value.shape[-1].astype("floatX")
norm = np.log(2) - (k / 2) * np.log(2 * np.pi) - logdet

v = 1 - (k / 2)
kernel = ((v / 2) * pt.log(quaddist / 2)) + pt.log(kv(v, pt.sqrt(2 * quaddist)))

logp_val = norm + kernel
return check_parameters(logp_val, posdef, msg="posdef scale")


@_logprob.register(MvAsymmetricLaplaceRV)
def mv_asymmetric_laplace_logp(op, values, rng, size, mu, cov, **kwargs):
[value] = values

chol_cov = nan_lower_cholesky(cov)
logdet, posdef = _logdet_from_cholesky(chol_cov)

# solve_triangular will raise if there are nans
# (which happens if the cholesky fails)
chol_cov = pt.switch(posdef[..., None, None], chol_cov, 1)

solve_x = solve_lower(chol_cov, value, b_ndim=1)
solve_mu = solve_lower(chol_cov, mu, b_ndim=1)

x_quaddist = (solve_x**2).sum(-1)
mu_quaddist = (solve_mu**2).sum(-1)
x_mu_quaddist = (value * solve_mu).sum(-1)

k = value.shape[-1].astype("floatX")
norm = np.log(2) - (k / 2) * np.log(2 * np.pi) - logdet

v = 1 - (k / 2)
kernel = (
x_mu_quaddist
+ ((v / 2) * (pt.log(x_quaddist) - pt.log(2 + mu_quaddist)))
+ pt.log(kv(v, pt.sqrt((2 + mu_quaddist) * x_quaddist)))
)

logp_val = norm + kernel
return check_parameters(logp_val, posdef, msg="posdef scale")


@_mean.register(MvLaplaceRV)
@_mean.register(MvAsymmetricLaplaceRV)
def mv_laplace_mean(op, rv, rng, size, mu, cov):
if rv_size_is_none(size):
bcast_mu, _ = pt.random.utils.broadcast_params([mu, cov], ndims_params=[1, 2])
else:
bcast_mu = pt.broadcast_to(mu, pt.concatenate([size, [mu.shape[-1]]]))
return bcast_mu


@_support_point.register(MvLaplaceRV)
@_support_point.register(MvAsymmetricLaplaceRV)
def mv_laplace_support_point(op, rv, rng, size, mu, cov):
# We have a 0 * inf when value = mu. I assume density is infinite, which isn't a good starting point.
point = mu + 1
if rv_size_is_none(size):
bcast_point, _ = pt.random.utils.broadcast_params([point, cov], ndims_params=[1, 2])
else:
bcast_shape = pt.concatenate([size, [point.shape[-1]]])
bcast_point = pt.broadcast_to(point, bcast_shape)
return bcast_point
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pymc>=5.17.0
pymc>=5.18.0
scikit-learn
19 changes: 0 additions & 19 deletions tests/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +0,0 @@
# Copyright 2022 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from pymc_experimental.distributions import histogram_utils
from pymc_experimental.distributions.histogram_utils import histogram_approximation

__all__ = ["histogram_utils", "histogram_approximation"]
Empty file.
Loading
Loading