diff --git a/pymc_experimental/inference/pathfinder/pathfinder.py b/pymc_experimental/inference/pathfinder/pathfinder.py
index 7276d253..9d46361c 100644
--- a/pymc_experimental/inference/pathfinder/pathfinder.py
+++ b/pymc_experimental/inference/pathfinder/pathfinder.py
@@ -302,7 +302,8 @@ def bfgs_sample(
     alpha,
     beta,
     gamma,
-    random_seed: RandomSeed | None = None,
+    # random_seed: RandomSeed | None = None,
+    rng,
 ):
     # batch: L = 8
     # alpha_l: (N,) => (L, N)
@@ -315,7 +316,7 @@ def bfgs_sample(
     # logdensity: (M,) => (L, M)
     # theta: (J, N)
 
-    rng = pytensor.shared(np.random.default_rng(seed=random_seed))
+    # rng = pytensor.shared(np.random.default_rng(seed=random_seed))
 
     def batched(x, g, alpha, beta, gamma):
         var_list = [x, g, alpha, beta, gamma]
@@ -380,6 +381,64 @@ def compute_logp(logp_func, arr):
     return np.where(np.isnan(logP), -np.inf, logP)
 
 
+_x = pt.matrix("_x", dtype="float64")
+_g = pt.matrix("_g", dtype="float64")
+_alpha = pt.matrix("_alpha", dtype="float64")
+_beta = pt.tensor3("_beta", dtype="float64")
+_gamma = pt.tensor3("_gamma", dtype="float64")
+_epsilon = pt.scalar("_epsilon", dtype="float64")
+_maxcor = pt.iscalar("_maxcor")
+_alpha, _S, _Z, _update_mask = alpha_recover(_x, _g, epsilon=_epsilon)
+_beta, _gamma = inverse_hessian_factors(_alpha, _S, _Z, _update_mask, J=_maxcor)
+
+_num_elbo_draws = pt.iscalar("_num_elbo_draws")
+_dummy_rng = pytensor.shared(np.random.default_rng(), name="_dummy_rng")
+_phi, _logQ_phi = bfgs_sample(
+    num_samples=_num_elbo_draws,
+    x=_x,
+    g=_g,
+    alpha=_alpha,
+    beta=_beta,
+    gamma=_gamma,
+    rng=_dummy_rng,
+)
+
+_num_draws = pt.iscalar("_num_draws")
+_x_lstar = pt.dvector("_x_lstar")
+_g_lstar = pt.dvector("_g_lstar")
+_alpha_lstar = pt.dvector("_alpha_lstar")
+_beta_lstar = pt.dmatrix("_beta_lstar")
+_gamma_lstar = pt.dmatrix("_gamma_lstar")
+
+
+_psi, _logQ_psi = bfgs_sample(
+    num_samples=_num_draws,
+    x=_x_lstar,
+    g=_g_lstar,
+    alpha=_alpha_lstar,
+    beta=_beta_lstar,
+    gamma=_gamma_lstar,
+    rng=_dummy_rng,
+)
+
+alpha_recover_compiled = pytensor.function(
+    inputs=[_x, _g, _epsilon],
+    outputs=[_alpha, _S, _Z, _update_mask],
+)
+inverse_hessian_factors_compiled = pytensor.function(
+    inputs=[_alpha, _S, _Z, _update_mask, _maxcor],
+    outputs=[_beta, _gamma],
+)
+bfgs_sample_compiled = pytensor.function(
+    inputs=[_num_elbo_draws, _x, _g, _alpha, _beta, _gamma],
+    outputs=[_phi, _logQ_phi],
+)
+bfgs_sample_lstar_compiled = pytensor.function(
+    inputs=[_num_draws, _x_lstar, _g_lstar, _alpha_lstar, _beta_lstar, _gamma_lstar],
+    outputs=[_psi, _logQ_psi],
+)
+
+
 def single_pathfinder(
     model,
     num_draws: int,
@@ -423,47 +482,46 @@ def neg_dlogp_func(x):
         maxls=maxls,
     )
 
-    # x_full, g_full: (L+1, N)
-    x_full = pt.as_tensor(lbfgs_history.x, dtype="float64")
-    g_full = pt.as_tensor(lbfgs_history.g, dtype="float64")
+    # x, g: (L+1, N)
+    x = lbfgs_history.x
+    g = lbfgs_history.g
+    alpha, S, Z, update_mask = alpha_recover_compiled(x, g, epsilon)
+    beta, gamma = inverse_hessian_factors_compiled(alpha, S, Z, update_mask, maxcor)
 
     # ignore initial point - x, g: (L, N)
-    x = x_full[1:]
-    g = g_full[1:]
-
-    alpha, S, Z, update_mask = alpha_recover(x_full, g_full, epsilon=epsilon)
-    beta, gamma = inverse_hessian_factors(alpha, S, Z, update_mask, J=maxcor)
-
-    phi, logQ_phi = bfgs_sample(
-        num_samples=num_elbo_draws,
-        x=x,
-        g=g,
-        alpha=alpha,
-        beta=beta,
-        gamma=gamma,
-        random_seed=pathfinder_seed,
+    x = x[1:]
+    g = g[1:]
+
+    rng = pytensor.shared(np.random.default_rng(pathfinder_seed), borrow=True)
+    phi, logQ_phi = bfgs_sample_compiled.copy(swap={_dummy_rng: rng})(
+        num_elbo_draws,
+        x,
+        g,
+        alpha,
+        beta,
+        gamma,
     )
 
     # .vectorize is slower than apply_along_axis
-    logP_phi = compute_logp(logp_func, phi.eval())
-    logQ_phi = logQ_phi.eval()
+    logP_phi = compute_logp(logp_func, phi)
+    # logQ_phi = logQ_phi.eval()
     elbo = (logP_phi - logQ_phi).mean(axis=-1)
     lstar = np.argmax(elbo)
 
     # BUG: elbo may all be -inf for all l in L. So np.argmax(elbo) will return 0 which is wrong. Still, this won't affect the posterior samples in the multipath Pathfinder scenario because of the PSIS/PSIR step. However, the user is left unaware of a failed Pathfinder run.
     # TODO: handle this case, e.g. by warning of a failed Pathfinder run and skipping the following bfgs_sample step to save time.
 
-    psi, logQ_psi = bfgs_sample(
-        num_samples=num_draws,
-        x=x[lstar],
-        g=g[lstar],
-        alpha=alpha[lstar],
-        beta=beta[lstar],
-        gamma=gamma[lstar],
-        random_seed=sample_seed,
+    rng.set_value(np.random.default_rng(sample_seed), borrow=True)
+    psi, logQ_psi = bfgs_sample_lstar_compiled.copy(swap={_dummy_rng: rng})(
+        num_draws,
+        x[lstar],
+        g[lstar],
+        alpha[lstar],
+        beta[lstar],
+        gamma[lstar],
    )
-    psi = psi.eval()
-    logQ_psi = logQ_psi.eval()
+    # psi = psi.eval()
+    # logQ_psi = logQ_psi.eval()
     logP_psi = compute_logp(logp_func, psi)
     # psi: (1, M, N)
     # logP_psi: (1, M)
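
Note on the pattern above (not part of the diff): the symbolic graphs are compiled once at module import with a placeholder RNG (`_dummy_rng`), and each pathfinder path then rebinds the RNG via `Function.copy(swap=...)` instead of recompiling. A minimal sketch of that pattern, using only `pytensor` and `numpy`; the names here are illustrative, not from the module:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# Compile once with a placeholder RNG shared variable.
_dummy_rng = pytensor.shared(np.random.default_rng(), name="_dummy_rng")
_n = pt.iscalar("_n")
_draws = pt.random.normal(0, 1, size=(_n,), rng=_dummy_rng)
draw_compiled = pytensor.function(inputs=[_n], outputs=_draws)

# Per call: swap in a freshly seeded RNG; the graph is not recompiled.
rng = pytensor.shared(np.random.default_rng(123))
draws = draw_compiled.copy(swap={_dummy_rng: rng})(5)
print(draws.shape)  # (5,)
```

The same idea underlies `bfgs_sample_compiled` / `bfgs_sample_lstar_compiled`: the compile cost is paid once per process rather than once per `single_pathfinder` call, while per-path seeding stays reproducible through the swapped-in RNG.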