PyMC Implementation of Pathfinder VI #386

Closed
wants to merge 10 commits into from
2 changes: 2 additions & 0 deletions pymc_experimental/inference/fit.py
@@ -31,11 +31,13 @@ def fit(method, **kwargs):
arviz.InferenceData
"""
if method == "pathfinder":
# TODO: Remove this once we have a pure PyMC implementation
if find_spec("blackjax") is None:
raise RuntimeError("Need BlackJAX to use `pathfinder`")

from pymc_experimental.inference.pathfinder import fit_pathfinder

# TODO: edit **kwargs to be more consistent with fit_pathfinder with blackjax and pymc backends.
return fit_pathfinder(**kwargs)

if method == "laplace":
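For context, the dispatcher in fit.py above is normally reached through pymc_experimental.fit. A minimal usage sketch follows; the toy model and observed data are illustrative assumptions and not part of this diff:

import pymc as pm
import pymc_experimental as pmx

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])

    # dispatches to fit_pathfinder(**kwargs); currently requires BlackJAX,
    # per the TODO in the hunk above
    idata = pmx.fit(method="pathfinder")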
134 changes: 0 additions & 134 deletions pymc_experimental/inference/pathfinder.py

This file was deleted.

3 changes: 3 additions & 0 deletions pymc_experimental/inference/pathfinder/__init__.py
@@ -0,0 +1,3 @@
from pymc_experimental.inference.pathfinder.pathfinder import fit_pathfinder

__all__ = ["fit_pathfinder"]
73 changes: 73 additions & 0 deletions pymc_experimental/inference/pathfinder/importance_sampling.py
@@ -0,0 +1,73 @@
import logging

import arviz as az
import numpy as np

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


def psir(
samples: np.ndarray,
logP: np.ndarray,
logQ: np.ndarray,
num_draws: int = 1000,
random_seed: int | None = None,
) -> np.ndarray:
"""Pareto Smoothed Importance Resampling (PSIR)
This implements the Pareto Smoothed Importance Resampling (PSIR) method described in Algorithm 5 of Zhang et al. (2022). PSIR follows a similar approach to the Algorithm 1 PSIS diagnostic from Yao et al. (2018). However, before computing the importance ratio r_s, logP and logQ are adjusted to account for the multiple estimators (or paths). The draws are then resampled from the original samples with replacement, with probabilities proportional to the importance weights computed by PSIS.

Parameters
----------
samples : np.ndarray
samples from proposal distribution
logP : np.ndarray
log probability of target distribution
logQ : np.ndarray
log probability of proposal distribution
num_draws : int
number of draws to return where num_draws <= samples.shape[0]
random_seed : int | None
random seed for reproducibility of the resampling step

Returns
-------
np.ndarray
importance sampled draws

Future work!
----------
- Implement the 3 sampling approaches and 5 weighting functions from Elvira et al. (2019)
- Implement Algorithm 2 VSBC marginal diagnostics from Yao et al. (2018)
- Incorporate these various diagnostics, sampling approaches and weighting functions into VI algorithms.

References
----------
Elvira, V., Martino, L., Luengo, D., & Bugallo, M. F. (2019). Generalized Multiple Importance Sampling. Statistical Science, 34(1), 129-155. https://doi.org/10.1214/18-STS668

Yao, Y., Vehtari, A., Simpson, D., & Gelman, A. (2018). Yes, but Did It Work?: Evaluating Variational Inference. arXiv:1802.02538 [Stat]. http://arxiv.org/abs/1802.02538

Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
"""

def logsumexp(x):
c = x.max()
return c + np.log(np.sum(np.exp(x - c)))

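# compute log importance weights (logP - logQ) and flatten them across paths/draws in column-major order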
logiw = np.reshape(logP - logQ, -1, order="F")
psislw, pareto_k = az.psislw(logiw)

# FIXME: pareto_k is mostly bad, find out why!
if pareto_k <= 0.70:
pass
elif 0.70 < pareto_k <= 1:
logger.warning("pareto_k is bad: %f", pareto_k)
logger.info("consider increasing ftol, gtol or maxcor parameters")
else:
logger.warning("pareto_k is very bad: %f", pareto_k)
logger.info(
"consider reparametrising the model, increasing ftol, gtol or maxcor parameters"
)

p = np.exp(psislw - logsumexp(psislw))
rng = np.random.default_rng(random_seed)
return rng.choice(samples, size=num_draws, p=p, shuffle=False, axis=0)
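A minimal sketch of calling psir on synthetic inputs; the Gaussian toy proposal and target below are illustrative assumptions, not part of this diff:

import numpy as np
from pymc_experimental.inference.pathfinder.importance_sampling import psir

rng = np.random.default_rng(0)
samples = rng.normal(size=(2000, 3))                 # draws from the proposal q
logQ = -0.5 * np.sum(samples**2, axis=1)             # unnormalised log q(x)
logP = -0.5 * np.sum((samples - 0.1) ** 2, axis=1)   # unnormalised log p(x)

# resample 500 draws with probabilities given by the Pareto-smoothed weights
resampled = psir(samples, logP=logP, logQ=logQ, num_draws=500, random_seed=1)
print(resampled.shape)  # (500, 3)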
92 changes: 92 additions & 0 deletions pymc_experimental/inference/pathfinder/lbfgs.py
@@ -0,0 +1,92 @@
from collections.abc import Callable
from typing import NamedTuple

import numpy as np

from scipy.optimize import minimize


class LBFGSHistory(NamedTuple):
x: np.ndarray
f: np.ndarray
g: np.ndarray


class LBFGSHistoryManager:
def __init__(self, fn: Callable, grad_fn: Callable, x0: np.ndarray, maxiter: int):
dim = x0.shape[0]
maxiter_add_one = maxiter + 1
# Pre-allocate arrays to save memory and improve speed
self.x_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
self.f_history = np.empty(maxiter_add_one, dtype=np.float64)
self.g_history = np.empty((maxiter_add_one, dim), dtype=np.float64)
self.count = 0
self.fn = fn
self.grad_fn = grad_fn
self.add_entry(x0, fn(x0), grad_fn(x0))

def add_entry(self, x, f, g=None):
self.x_history[self.count] = x
self.f_history[self.count] = f
if self.g_history is not None and g is not None:
self.g_history[self.count] = g
self.count += 1

def get_history(self):
# Return trimmed arrays up to the number of entries actually used
x = self.x_history[: self.count]
f = self.f_history[: self.count]
g = self.g_history[: self.count] if self.g_history is not None else None
return LBFGSHistory(
x=x,
f=f,
g=g,
)

def __call__(self, x):
self.add_entry(x, self.fn(x), self.grad_fn(x))


def lbfgs(
fn,
grad_fn,
x0: np.ndarray,
maxcor: int | None = None,
maxiter=1000,
ftol=1e-5,
gtol=1e-8,
maxls=1000,
**lbfgs_kwargs,
) -> LBFGSHistory:
def callback(xk):
lbfgs_history_manager(xk)

lbfgs_history_manager = LBFGSHistoryManager(
fn=fn,
grad_fn=grad_fn,
x0=x0,
maxiter=maxiter,
)

default_lbfgs_options = dict(
maxcor=maxcor,
maxiter=maxiter,
ftol=ftol,
gtol=gtol,
maxls=maxls,
)
options = lbfgs_kwargs.pop("options", {})
options = default_lbfgs_options | options

# TODO: return the status of the lbfgs optimisation to handle the case where the optimisation fails. More details in the _single_pathfinder function.

minimize(
fn,
x0,
method="L-BFGS-B",
jac=grad_fn,
options=options,
callback=callback,
**lbfgs_kwargs,
)
return lbfgs_history_manager.get_history()
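A minimal sketch of running lbfgs on a toy quadratic and inspecting the recorded optimisation path; the objective and gradient below are illustrative assumptions, not part of this diff:

import numpy as np

def fn(x):
    return 0.5 * np.sum((x - 2.0) ** 2)

def grad_fn(x):
    return x - 2.0

history = lbfgs(fn, grad_fn, x0=np.zeros(3), maxcor=5)
# history.x, history.f and history.g hold the iterates, objective values
# and gradients collected by LBFGSHistoryManager
print(history.x.shape, history.f.shape, history.g.shape)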