diff --git a/newt/basemodels.py b/newt/basemodels.py index 17b0492..2a38dc1 100644 --- a/newt/basemodels.py +++ b/newt/basemodels.py @@ -382,7 +382,7 @@ def compute_global_pseudo_lik(self): return pseudo_y_full, pseudo_var_full def compute_full_pseudo_lik(self): - nat1lik_full, nat2lik_full = vmap(self.compute_full_pseudo_nat)(self.obs_ind) + nat1lik_full, nat2lik_full = self.compute_full_pseudo_nat(self.obs_ind) # TODO: remove obs_ind pseudo_var_full = inv_vmap(nat2lik_full + 1e-12 * np.eye(nat2lik_full.shape[1])) pseudo_y_full = pseudo_var_full @ nat1lik_full return pseudo_y_full, pseudo_var_full @@ -391,8 +391,8 @@ def compute_full_pseudo_nat(self, batch_ind): Kuf = self.kernel(self.Z.value, self.X[batch_ind].reshape(-1, 1)) # only compute log lik for observed values Kuu = self.kernel(self.Z.value, self.Z.value) Wuf = solve(Kuu, Kuf) # conditional mapping, Kuu^-1 Kuf - nat1lik_full = Wuf @ self.pseudo_likelihood.nat1[batch_ind].reshape(-1, 1) - nat2lik_full = Wuf @ np.diag(self.pseudo_likelihood.nat2[batch_ind].reshape(-1)) @ transpose(Wuf) + nat1lik_full = Wuf.T[..., None] @ self.pseudo_likelihood.nat1[batch_ind] + nat2lik_full = Wuf.T[..., None] @ self.pseudo_likelihood.nat2[batch_ind] @ Wuf.T[:, None] return nat1lik_full, nat2lik_full def compute_kl(self): @@ -478,6 +478,23 @@ def conditional_posterior_to_data(self, batch_ind=None, post_mean=None, post_cov self.Z.value) return mean_f.reshape(Nbatch, 1, 1), cov_f.reshape(Nbatch, 1, 1) + def cavity_distribution(self, batch_ind=None, power=1.): + """ Compute the power EP cavity for the given data points """ + if batch_ind is None: + batch_ind = np.arange(self.num_data) + + nat1lik_full, nat2lik_full = self.compute_full_pseudo_nat(batch_ind) + + # then compute the cavity + cavity_mean, cavity_cov = vmap(compute_cavity, [None, None, 0, 0, None])( + self.posterior_mean.value[..., 0], + self.posterior_covariance.value, + nat1lik_full, + nat2lik_full, + power + ) + return cavity_mean, cavity_cov + class MarkovGP(BaseModel): """ diff --git a/newt/inference.py b/newt/inference.py index 572fe09..abf5b75 100644 --- a/newt/inference.py +++ b/newt/inference.py @@ -276,7 +276,7 @@ def energy(self, batch_ind=None, cubature=None, power=1.): """ if batch_ind is None: batch_ind = np.arange(self.num_data) - scale = 1 + scale = 1. else: scale = self.num_data / batch_ind.shape[0] @@ -318,7 +318,7 @@ def energy(self, batch_ind=None, cubature=None, power=1.): ep_energy = -( lZ_post - + 1 / power * (scale * np.nansum(lZ) - np.nansum(lZ_pseudo)) + + 1. / power * (scale * np.nansum(lZ) - np.nansum(lZ_pseudo)) ) return ep_energy