diff --git a/pymc_experimental/inference/pathfinder/importance_sampling.py b/pymc_experimental/inference/pathfinder/importance_sampling.py
index b45207c3..a7c0785c 100644
--- a/pymc_experimental/inference/pathfinder/importance_sampling.py
+++ b/pymc_experimental/inference/pathfinder/importance_sampling.py
@@ -8,7 +8,6 @@
 import pytensor.tensor as pt
 
 from pytensor.graph import Apply, Op
-from pytensor.tensor.variable import TensorVariable
 
 logger = logging.getLogger(__name__)
 
@@ -34,12 +33,11 @@ def perform(self, node: Apply, inputs, outputs) -> None:
 
 
 def importance_sampling(
-    samples: TensorVariable,
-    # logP: TensorVariable,
-    # logQ: TensorVariable,
-    logiw: TensorVariable,
+    samples: np.ndarray,
+    logP: np.ndarray,
+    logQ: np.ndarray,
     num_draws: int,
     method: Literal["psis", "psir", "identity", "none"],
     random_seed: int | None = None,
 ) -> np.ndarray:
     """Pareto Smoothed Importance Resampling (PSIR)
@@ -79,21 +77,36 @@ def importance_sampling(
 
     Zhang, L., Carpenter, B., Gelman, A., & Vehtari, A. (2022). Pathfinder: Parallel quasi-Newton variational inference. Journal of Machine Learning Research, 23(306), 1-49.
     """
-    if method == "psis":
-        replace = False
-        logiw, pareto_k = PSIS()(logiw)
-    elif method == "psir":
-        replace = True
-        logiw, pareto_k = PSIS()(logiw)
-    elif method == "identity":
-        replace = False
-        logiw = logiw
-        pareto_k = None
-    elif method == "none":
+    num_paths, num_pdraws, N = samples.shape
+
+    if method == "none":
         logger.warning(
             "importance sampling is disabled. The samples are returned as is which may include samples from failed paths with non-finite logP or logQ values. It is recommended to use importance_sampling='psis' for better stability."
         )
         return samples
+    else:
+        samples = samples.reshape(-1, N)
+        logP = logP.ravel()
+        logQ = logQ.ravel()
+
+        # adjust log densities
+        log_I = np.log(num_paths)
+        logP -= log_I
+        logQ -= log_I
+        logiw = logP - logQ
+
+        if method == "psis":
+            replace = False
+            logiw, pareto_k = PSIS()(logiw)
+        elif method == "psir":
+            replace = True
+            logiw, pareto_k = PSIS()(logiw)
+        elif method == "identity":
+            replace = False
+            logiw = logiw
+            pareto_k = None
+        else:
+            raise ValueError(f"Invalid importance sampling method: {method}")
 
     # NOTE: Pareto k is normally bad for Pathfinder even when the posterior is close to the NUTS posterior or closer to NUTS than ADVI.
     # Pareto k may not be a good diagnostic for Pathfinder.
@@ -121,7 +134,7 @@
                 "Consider reparametrising the model all together or ensure the input data are correct."
) - logger.warning(f"Pareto k value: {pareto_k:.2f}") + logger.warning(f"Pareto k value: {pareto_k:.2f}") p = pt.exp(logiw - pt.logsumexp(logiw)).eval() rng = np.random.default_rng(random_seed) diff --git a/pymc_experimental/inference/pathfinder/pathfinder.py b/pymc_experimental/inference/pathfinder/pathfinder.py index 56c68d0d..5e8573cf 100644 --- a/pymc_experimental/inference/pathfinder/pathfinder.py +++ b/pymc_experimental/inference/pathfinder/pathfinder.py @@ -118,7 +118,13 @@ def convert_flat_trace_to_idata( postprocessing_backend="cpu", inference_backend="pymc", model=None, + importance_sampling: Literal["psis", "psir", "identity", "none"] = "psis", ): + if importance_sampling == "none": + # samples.ndim == 3 in this case, otherwise ndim == 2 + num_paths, num_pdraws, N = samples.shape + samples = samples.reshape(-1, N) + model = modelcontext(model) ip = model.initial_point() ip_point_map_info = DictToArrayBijection.map(ip).point_map_info @@ -152,6 +158,10 @@ def convert_flat_trace_to_idata( ) fn.trust_input = True result = fn(*list(trace.values())) + + if importance_sampling == "none": + result = [res.reshape(num_paths, num_pdraws, *res.shape[2:]) for res in result] + elif inference_backend == "blackjax": jax_fn = get_jaxified_graph(inputs=model.value_vars, outputs=vars_to_sample) result = jax.vmap(jax.vmap(jax_fn))( @@ -731,7 +741,6 @@ def multipath_pathfinder( **pathfinder_kwargs, ): *path_seeds, choice_seed = _get_seeds_per_chain(random_seed, num_paths + 1) - N = DictToArrayBijection.map(model.initial_point()).data.shape[0] single_pathfinder_fn = make_single_pathfinder_fn( model, @@ -808,19 +817,11 @@ def multipath_pathfinder( logP = np.concatenate(logP) logQ = np.concatenate(logQ) - samples = samples.reshape(-1, N) - logP = logP.ravel() - logQ = logQ.ravel() - - # adjust log densities - log_I = np.log(num_paths) - logP -= log_I - logQ -= log_I - logiw = logP - logQ - return _importance_sampling( samples=samples, - logiw=logiw, + logP=logP, + logQ=logQ, + # logiw=logiw, num_draws=num_draws, method=importance_sampling, random_seed=choice_seed, @@ -881,7 +882,7 @@ def fit_pathfinder( epsilon: float value used to filter out large changes in the direction of the update gradient at each iteration l in L. Iteration l is only accepted if delta_theta[l] * delta_grad[l] > epsilon * L2_norm(delta_grad[l]) for each l in L. (default is 1e-8). importance_sampling : str, optional - importance sampling method to use. Options are "psis" (default), "psir", "identity", "none. Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size num_draws_per_path * num_paths. + importance sampling method to use which applies sampling based on the log importance weights equal to logP - logQ. Options are "psis" (default), "psir", "identity", "none". Pareto Smoothed Importance Sampling (psis) is recommended in many cases for more stable results than Pareto Smoothed Importance Resampling (psir). identity applies the log importance weights directly without resampling. none applies no importance sampling weights and returns the samples as is of size (num_paths, num_draws_per_path, N) where N is the number of model parameters, otherwise sample size is (num_draws, N). progressbar : bool, optional Whether to display a progress bar (default is False). 
Setting this to True will likely increase the computation time. random_seed : RandomSeed, optional @@ -974,5 +975,6 @@ def fit_pathfinder( postprocessing_backend=postprocessing_backend, inference_backend=inference_backend, model=model, + importance_sampling=importance_sampling, ) return idata diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index 0331b60c..070e7328 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -45,7 +45,8 @@ def test_pathfinder(inference_backend): with model: idata = pmx.fit( method="pathfinder", - num_paths=20, + num_paths=50, + jitter=10.0, random_seed=41, inference_backend=inference_backend, ) @@ -53,13 +54,13 @@ def test_pathfinder(inference_backend): assert idata.posterior["mu"].shape == (1, 1000) assert idata.posterior["tau"].shape == (1, 1000) assert idata.posterior["theta"].shape == (1, 1000, 8) - # NOTE: Pathfinder tends to return means around 7 and tau around 0.58. So need to increase atol by a large amount. if inference_backend == "pymc": - np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=2.5) - np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=3.8) + np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0, atol=1.6) + np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=1.5) def test_bfgs_sample(): + import pytensor import pytensor.tensor as pt from pymc_experimental.inference.pathfinder.pathfinder import ( @@ -73,6 +74,7 @@ def test_bfgs_sample(): L = Lp1 - 1 J = 6 num_samples = 1000 + rng = pytensor.shared(np.random.default_rng(42), name="rng") # mock data x_data = np.random.randn(Lp1, N) @@ -90,6 +92,7 @@ def test_bfgs_sample(): # sample phi, logq = bfgs_sample( + rng=rng, num_samples=num_samples, x=x, g=g,
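
Review note: this patch moves the flattening and log-density adjustment out of multipath_pathfinder and into importance_sampling, which now receives logP and logQ instead of a precomputed logiw. Below is a minimal NumPy sketch of that pipeline under toy inputs standing in for real Pathfinder output; Pareto smoothing (PSIS) is omitted, so the resampling step shown corresponds to raw-weight "psir" rather than the smoothed default:

    import numpy as np

    rng = np.random.default_rng(42)
    num_paths, num_pdraws, N = 4, 250, 3

    # toy stand-ins for the concatenated per-path Pathfinder output
    samples = rng.normal(size=(num_paths, num_pdraws, N))
    logP = rng.normal(size=(num_paths, num_pdraws))  # target log density
    logQ = rng.normal(size=(num_paths, num_pdraws))  # approximation log density

    # flatten paths, as the patched function does for every method except "none"
    samples = samples.reshape(-1, N)
    logP, logQ = logP.ravel(), logQ.ravel()

    # adjust log densities for the uniform mixture over paths; subtracting
    # log(num_paths) from both logP and logQ cancels in their difference, so
    # logiw is unchanged -- the patch keeps it to make the mixture explicit
    log_I = np.log(num_paths)
    logP -= log_I
    logQ -= log_I
    logiw = logP - logQ

    # resample with normalized importance weights (replace=True mirrors "psir")
    p = np.exp(logiw - np.logaddexp.reduce(logiw))
    idx = rng.choice(samples.shape[0], size=1000, replace=True, p=p)
    print(samples[idx].shape)  # (1000, 3)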
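
From the user side, fit_pathfinder now forwards importance_sampling into convert_flat_trace_to_idata, so the chosen method also determines the shape of the returned posterior. A usage sketch with an illustrative toy model (the model and numbers are made up; pmx.fit(method="pathfinder", ...), num_paths, and random_seed follow the test above, and the importance_sampling options follow the updated docstring):

    import numpy as np
    import pymc as pm
    import pymc_experimental as pmx

    observed = np.random.default_rng(0).normal(5.0, 2.0, size=100)
    with pm.Model() as model:
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 5.0)
        pm.Normal("y", mu, sigma, observed=observed)

        # default: PSIS-weighted draws, posterior variables of shape (1, num_draws)
        idata = pmx.fit(
            method="pathfinder",
            num_paths=8,
            importance_sampling="psis",
            random_seed=41,
        )

        # disabled: raw draws keep the path dimension, i.e.
        # (num_paths, num_draws_per_path, ...) per the updated docstring
        idata_raw = pmx.fit(
            method="pathfinder",
            num_paths=8,
            importance_sampling="none",
            random_seed=41,
        )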