From 87d4aea8d3c3003e641e0de6d5fcb861a224c189 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Carsten=20J=C3=B8rgensen?=
Date: Mon, 1 Jul 2024 17:56:09 +0200
Subject: [PATCH] Implement Laplace (quadratic) approximation (#345)

* First draft of quadratic approximation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review comments incorporated

* License and copyright information added

* Only add additional data to inferencedata when chains!=0

* Raise error if Hessian is singular

* Replace for loop with call to remove_value_transforms

* Pass model directly when finding MAP and the Hessian

* Update pymc_experimental/inference/laplace.py

Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>

* Remove chains from public parameters for Laplace approx method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Parameter draws is not optional with default value 1000

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add warning if number of variables in vars does not equal number of model variables

* Update version.txt

* `shock_size` should never be scalar

* Blackjax API change

* Handle latest PyMC/PyTensor breaking changes

* Temporarily mark two tests as xfail

* More bugfixes for statespace (#346)

* Allow forward sampling of statespace models in JAX mode

Explicitly set data shape to avoid broadcasting error

Better handling of measurement error dims in `SARIMAX` models

Freeze auxiliary models before forward sampling

Bugfixes for posterior predictive sampling helpers

Allow specification of time dimension name when registering data

Save info about exogenous data for post-estimation tasks

Restore `_exog_data_info` member variable

Be more consistent with the names of filter outputs

* Adjust test suite to reflect API changes

Modify structural tests to accommodate deterministic models

Save kalman filter outputs to idata for statespace tests

Remove test related to `add_exogenous`

Adjust structural module tests

* Add JAX test suite

* Bug-fixes and changes to statespace distributions

Remove tests related to the `add_exogenous` method

Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"`

Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs

Add signature and simple test for `SequenceMvNormal`

* Re-run example notebooks

* Add helper function to sample prior/posterior statespace matrices

* fix tests

* Wrap jax MvNormal rewrite in try/except block

* Don't use `action` keyword in `catch_warnings`

* Skip JAX test if `numpyro` is not installed

* Handle batch dims on `SequenceMvNormal`

* Remove unused batch_dim logic in SequenceMvNormal

* Restore `get_support_shape_1d` import

* Fix failing test case for laplace

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Co-authored-by: Jesse Grabowski
Co-authored-by: Ricardo Vieira
Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com>
---
 pymc_experimental/inference/fit.py      |   8 +-
 pymc_experimental/inference/laplace.py  | 190 ++++++++++++++++++++++++
 pymc_experimental/tests/test_laplace.py | 137 +++++++++++++++++
 3 files changed, 334 insertions(+), 1 deletion(-)
 create mode 100644 pymc_experimental/inference/laplace.py
 create mode 100644 pymc_experimental/tests/test_laplace.py
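For orientation before the diffs: the Laplace (quadratic) approximation implemented below replaces the posterior with a Gaussian centered at the MAP estimate, with covariance equal to the inverse Hessian of the negative log posterior at that point. Here is a minimal standalone sketch of that idea in plain NumPy/SciPy, on a toy model where the approximation happens to be exact; it is an illustration only, independent of the API added by this patch, and all names in it are made up.

import numpy as np
from scipy.optimize import minimize

# Toy model: y ~ Normal(mu, 1) with a flat prior on mu, so the exact
# posterior is Normal(mean(y), 1/n) and the quadratic approximation is exact.
y = np.array([2642.0, 3503.0, 4358.0])

def neg_log_posterior(mu):
    # Negative log posterior, up to an additive constant.
    return 0.5 * np.sum((y - mu) ** 2)

# Step 1: find the posterior mode (the MAP estimate).
map_estimate = minimize(neg_log_posterior, x0=np.array([0.0])).x

# Step 2: Hessian of the negative log posterior at the mode.
# For this model it is available analytically: the second derivative is n.
hessian = np.array([[float(y.size)]])

# Step 3: approximate the posterior by N(map_estimate, inv(hessian)) and
# sample from it, mirroring what the laplace() function below does.
cov = np.linalg.inv(hessian)
samples = np.random.default_rng(0).multivariate_normal(map_estimate, cov, size=1000)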
diff --git a/pymc_experimental/inference/fit.py b/pymc_experimental/inference/fit.py
index 565b59e3..71dfb9f8 100644
--- a/pymc_experimental/inference/fit.py
+++ b/pymc_experimental/inference/fit.py
@@ -21,7 +21,7 @@ def fit(method, **kwargs):
     ----------
     method : str
         Which inference method to run.
-        Supported: pathfinder
+        Supported: pathfinder or laplace
 
     kwargs are passed on.
 
@@ -38,3 +38,9 @@ def fit(method, **kwargs):
         from pymc_experimental.inference.pathfinder import fit_pathfinder
 
         return fit_pathfinder(**kwargs)
+
+    if method == "laplace":
+
+        from pymc_experimental.inference.laplace import laplace
+
+        return laplace(**kwargs)
diff --git a/pymc_experimental/inference/laplace.py b/pymc_experimental/inference/laplace.py
new file mode 100644
index 00000000..1508b6e8
--- /dev/null
+++ b/pymc_experimental/inference/laplace.py
@@ -0,0 +1,190 @@
+# Copyright 2024 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from collections.abc import Sequence
+from typing import Optional
+
+import arviz as az
+import numpy as np
+import pymc as pm
+import xarray as xr
+from arviz import dict_to_dataset
+from pymc.backends.arviz import (
+    coords_and_dims_for_inferencedata,
+    find_constants,
+    find_observations,
+)
+from pymc.model.transform.conditioning import remove_value_transforms
+from pymc.util import RandomSeed
+from pytensor import Variable
+
+
+def laplace(
+    vars: Sequence[Variable],
+    draws: Optional[int] = 1000,
+    model=None,
+    random_seed: Optional[RandomSeed] = None,
+    progressbar=True,
+):
+    """
+    Create a Laplace (quadratic) approximation for a posterior distribution.
+
+    This function generates a Laplace approximation for a given posterior distribution using a specified
+    number of draws. This is useful for obtaining a parametric approximation to the posterior distribution
+    that can be used for further analysis.
+
+    Parameters
+    ----------
+    vars : Sequence[Variable]
+        A sequence of variables for which the Laplace approximation of the posterior distribution
+        is to be created.
+    draws : Optional[int], default=1000
+        The number of draws to sample from the posterior distribution for creating the approximation.
+        If draws=None, only the fit of the Laplace approximation is returned.
+    model : object, optional, default=None
+        The model object that defines the posterior distribution. If None, the default model will be used.
+    random_seed : Optional[RandomSeed], default=None
+        An optional random seed to ensure reproducibility of the draws. If None, the draws will be
+        generated using the current random state.
+    progressbar : bool, optional, default=True
+        Whether to display a progress bar in the command line.
+
+    Returns
+    -------
+    arviz.InferenceData
+        An `InferenceData` object from the `arviz` library containing the Laplace
+        approximation of the posterior distribution. The InferenceData object also
+        contains constant and observed data as well as deterministic variables.
+        InferenceData also contains a group 'fit' with the mean and covariance
+        for the Laplace approximation.
+
+    Examples
+    --------
+
+    >>> import numpy as np
+    >>> import pymc as pm
+    >>> import arviz as az
+    >>> from pymc_experimental.inference.laplace import laplace
+    >>> y = np.array([2642, 3503, 4358] * 10)
+    >>> with pm.Model() as m:
+    ...     logsigma = pm.Uniform("logsigma", 1, 100)
+    ...     mu = pm.Uniform("mu", -10000, 10000)
+    ...     yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+    >>> idata = laplace([mu, logsigma], model=m)
+
+    Notes
+    -----
+    This method of approximation may not be suitable for all types of posterior distributions,
+    especially those with significant skewness or multimodality.
+
+    See Also
+    --------
+    fit : Calling the inference function ``fit`` like ``pmx.fit(method="laplace", vars=[mu, logsigma], model=m)``
+        will forward the call to ``laplace``.
+
+    """
+
+    rng = np.random.default_rng(seed=random_seed)
+
+    transformed_m = pm.modelcontext(model)
+
+    if len(vars) != len(transformed_m.free_RVs):
+        warnings.warn(
+            "Number of variables in vars does not equal the number of variables in the model.",
+            UserWarning,
+        )
+
+    map = pm.find_MAP(vars=vars, progressbar=progressbar, model=transformed_m)
+
+    # See https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html
+    untransformed_m = remove_value_transforms(transformed_m)
+    untransformed_vars = [untransformed_m[v.name] for v in vars]
+    hessian = pm.find_hessian(point=map, vars=untransformed_vars, model=untransformed_m)
+
+    if np.linalg.det(hessian) == 0:
+        raise np.linalg.LinAlgError("Hessian is singular.")
+
+    cov = np.linalg.inv(hessian)
+    mean = np.concatenate([np.atleast_1d(map[v.name]) for v in vars])
+
+    chains = 1
+
+    if draws is not None:
+        samples = rng.multivariate_normal(mean, cov, size=(chains, draws))
+
+        data_vars = {}
+        for i, var in enumerate(vars):
+            data_vars[str(var)] = xr.DataArray(samples[:, :, i], dims=("chain", "draw"))
+
+        coords = {"chain": np.arange(chains), "draw": np.arange(draws)}
+        ds = xr.Dataset(data_vars, coords=coords)
+
+        idata = az.convert_to_inference_data(ds)
+        idata = addDataToInferenceData(model, idata, progressbar)
+    else:
+        idata = az.InferenceData()
+
+    idata = addFitToInferenceData(vars, idata, mean, cov)
+
+    return idata
+
+
+def addFitToInferenceData(vars, idata, mean, covariance):
+    coord_names = [v.name for v in vars]
+
+    # Convert the mean vector and covariance matrix to xarray DataArrays
+    mean_dataarray = xr.DataArray(mean, dims=["rows"], coords={"rows": coord_names})
+    cov_dataarray = xr.DataArray(
+        covariance, dims=["rows", "columns"], coords={"rows": coord_names, "columns": coord_names}
+    )
+
+    # Create xarray dataset
+    dataset = xr.Dataset({"mean_vector": mean_dataarray, "covariance_matrix": cov_dataarray})
+
+    idata.add_groups(fit=dataset)
+
+    return idata
+
+
+def addDataToInferenceData(model, trace, progressbar):
+    # Add deterministic variables to inference data
+    trace.posterior = pm.compute_deterministics(
+        trace.posterior, model=model, merge_dataset=True, progressbar=progressbar
+    )
+
+    coords, dims = coords_and_dims_for_inferencedata(model)
+
+    observed_data = dict_to_dataset(
+        find_observations(model),
+        library=pm,
+        coords=coords,
+        dims=dims,
+        default_dims=[],
+    )
+
+    constant_data = dict_to_dataset(
+        find_constants(model),
+        library=pm,
+        coords=coords,
+        dims=dims,
+        default_dims=[],
+    )
+
+    trace.add_groups(
+        {"observed_data": observed_data, "constant_data": constant_data},
+        coords=coords,
+        dims=dims,
+    )
+
+    return trace
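The distinguishing output of the module above is the extra "fit" group on the returned InferenceData, holding the mean vector and covariance matrix of the approximation. A short usage sketch, assuming the patch is applied and reusing the model from the docstring example:

import numpy as np
import pymc as pm
from pymc_experimental.inference.laplace import laplace

y = np.array([2642, 3503, 4358] * 10)
with pm.Model() as m:
    logsigma = pm.Uniform("logsigma", 1, 100)
    mu = pm.Uniform("mu", -10000, 10000)
    yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)

# draws=None skips posterior sampling and returns only the fit group.
idata = laplace([mu, logsigma], draws=None, model=m)

mean = idata.fit["mean_vector"].values        # MAP estimate, one entry per variable
cov = idata.fit["covariance_matrix"].values   # inverse Hessian at the MAP
print(dict(zip(idata.fit["rows"].values, mean)))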
diff --git a/pymc_experimental/tests/test_laplace.py b/pymc_experimental/tests/test_laplace.py
new file mode 100644
index 00000000..49e5614b
--- /dev/null
+++ b/pymc_experimental/tests/test_laplace.py
@@ -0,0 +1,137 @@
+# Copyright 2024 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import pymc as pm
+import pytest
+
+import pymc_experimental as pmx
+
+
+@pytest.mark.filterwarnings(
+    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
+    + "To suppress this warning set `negate_output=False`:FutureWarning",
+)
+def test_laplace():
+
+    # Example originates from Bayesian Data Analysis, 3rd Edition,
+    # by Andrew Gelman, John Carlin, Hal Stern, David Dunson,
+    # Aki Vehtari, and Donald Rubin.
+    # See section 4.1.
+
+    y = np.array([2642, 3503, 4358], dtype=np.float64)
+    n = y.size
+    draws = 100000
+
+    with pm.Model() as m:
+        logsigma = pm.Uniform("logsigma", 1, 100)
+        mu = pm.Uniform("mu", -10000, 10000)
+        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+        vars = [mu, logsigma]
+
+        idata = pmx.fit(
+            method="laplace",
+            vars=vars,
+            model=m,
+            draws=draws,
+            random_seed=173300,
+        )
+
+    assert idata.posterior["mu"].shape == (1, draws)
+    assert idata.posterior["logsigma"].shape == (1, draws)
+    assert idata.observed_data["y"].shape == (n,)
+    assert idata.fit["mean_vector"].shape == (len(vars),)
+    assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars))
+
+    bda_map = [y.mean(), np.log(y.std())]
+    bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
+
+    assert np.allclose(idata.fit["mean_vector"].values, bda_map)
+    assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
+
+
+@pytest.mark.filterwarnings(
+    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
+    + "To suppress this warning set `negate_output=False`:FutureWarning",
+)
+def test_laplace_only_fit():
+
+    # Example originates from Bayesian Data Analysis, 3rd Edition,
+    # by Andrew Gelman, John Carlin, Hal Stern, David Dunson,
+    # Aki Vehtari, and Donald Rubin.
+    # See section 4.1.
+
+    y = np.array([2642, 3503, 4358], dtype=np.float64)
+    n = y.size
+
+    with pm.Model() as m:
+        logsigma = pm.Uniform("logsigma", 1, 100)
+        mu = pm.Uniform("mu", -10000, 10000)
+        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+        vars = [mu, logsigma]
+
+        idata = pmx.fit(
+            method="laplace",
+            vars=vars,
+            draws=None,
+            model=m,
+            random_seed=173300,
+        )
+
+    assert idata.fit["mean_vector"].shape == (len(vars),)
+    assert idata.fit["covariance_matrix"].shape == (len(vars), len(vars))
+
+    bda_map = [y.mean(), np.log(y.std())]
+    bda_cov = np.array([[y.var() / n, 0], [0, 1 / (2 * n)]])
+
+    assert np.allclose(idata.fit["mean_vector"].values, bda_map)
+    assert np.allclose(idata.fit["covariance_matrix"].values, bda_cov, atol=1e-4)
+
+
+@pytest.mark.filterwarnings(
+    "ignore:hessian will stop negating the output in a future version of PyMC.\n"
+    + "To suppress this warning set `negate_output=False`:FutureWarning",
+)
+def test_laplace_subset_of_rv(recwarn):
+
+    # Example originates from Bayesian Data Analysis, 3rd Edition,
+    # by Andrew Gelman, John Carlin, Hal Stern, David Dunson,
+    # Aki Vehtari, and Donald Rubin.
+    # See section 4.1.
+
+    y = np.array([2642, 3503, 4358], dtype=np.float64)
+    n = y.size
+
+    with pm.Model() as m:
+        logsigma = pm.Uniform("logsigma", 1, 100)
+        mu = pm.Uniform("mu", -10000, 10000)
+        yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y)
+        vars = [mu]
+
+        idata = pmx.fit(
+            method="laplace",
+            vars=vars,
+            draws=None,
+            model=m,
+            random_seed=173300,
+        )
+
+    assert len(recwarn) == 3
+    w = recwarn.pop(UserWarning)
+    assert issubclass(w.category, UserWarning)
+    assert (
+        str(w.message)
+        == "Number of variables in vars does not equal the number of variables in the model."
+    )
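For reference, the analytic values the tests compare against follow from the normal example in Bayesian Data Analysis, section 4.1: under the (mu, logsigma) parameterization the posterior mode is (mean(y), log(std(y))) and the inverse observed information is diag(var(y)/n, 1/(2n)). A quick standalone check of those numbers (plain NumPy, not part of the test suite):

import numpy as np

y = np.array([2642, 3503, 4358], dtype=np.float64)
n = y.size

bda_map = np.array([y.mean(), np.log(y.std())])   # expected mean_vector
bda_cov = np.diag([y.var() / n, 1 / (2 * n)])     # expected covariance_matrix
print(bda_map)
print(bda_cov)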