Skip to content

Commit

Permalink
fix sparse EP cavity - energy still not working
Browse files Browse the repository at this point in the history
  • Loading branch information
William Wilkinson committed Jun 18, 2021
1 parent a65f15a commit 1d5b2dd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
23 changes: 20 additions & 3 deletions newt/basemodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def compute_global_pseudo_lik(self):
return pseudo_y_full, pseudo_var_full

def compute_full_pseudo_lik(self):
nat1lik_full, nat2lik_full = vmap(self.compute_full_pseudo_nat)(self.obs_ind)
nat1lik_full, nat2lik_full = self.compute_full_pseudo_nat(self.obs_ind) # TODO: remove obs_ind
pseudo_var_full = inv_vmap(nat2lik_full + 1e-12 * np.eye(nat2lik_full.shape[1]))
pseudo_y_full = pseudo_var_full @ nat1lik_full
return pseudo_y_full, pseudo_var_full
Expand All @@ -391,8 +391,8 @@ def compute_full_pseudo_nat(self, batch_ind):
Kuf = self.kernel(self.Z.value, self.X[batch_ind].reshape(-1, 1)) # only compute log lik for observed values
Kuu = self.kernel(self.Z.value, self.Z.value)
Wuf = solve(Kuu, Kuf) # conditional mapping, Kuu^-1 Kuf
nat1lik_full = Wuf @ self.pseudo_likelihood.nat1[batch_ind].reshape(-1, 1)
nat2lik_full = Wuf @ np.diag(self.pseudo_likelihood.nat2[batch_ind].reshape(-1)) @ transpose(Wuf)
nat1lik_full = Wuf.T[..., None] @ self.pseudo_likelihood.nat1[batch_ind]
nat2lik_full = Wuf.T[..., None] @ self.pseudo_likelihood.nat2[batch_ind] @ Wuf.T[:, None]
return nat1lik_full, nat2lik_full

def compute_kl(self):
Expand Down Expand Up @@ -478,6 +478,23 @@ def conditional_posterior_to_data(self, batch_ind=None, post_mean=None, post_cov
self.Z.value)
return mean_f.reshape(Nbatch, 1, 1), cov_f.reshape(Nbatch, 1, 1)

def cavity_distribution(self, batch_ind=None, power=1.):
""" Compute the power EP cavity for the given data points """
if batch_ind is None:
batch_ind = np.arange(self.num_data)

nat1lik_full, nat2lik_full = self.compute_full_pseudo_nat(batch_ind)

# then compute the cavity
cavity_mean, cavity_cov = vmap(compute_cavity, [None, None, 0, 0, None])(
self.posterior_mean.value[..., 0],
self.posterior_covariance.value,
nat1lik_full,
nat2lik_full,
power
)
return cavity_mean, cavity_cov


class MarkovGP(BaseModel):
"""
Expand Down
4 changes: 2 additions & 2 deletions newt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def energy(self, batch_ind=None, cubature=None, power=1.):
"""
if batch_ind is None:
batch_ind = np.arange(self.num_data)
scale = 1
scale = 1.
else:
scale = self.num_data / batch_ind.shape[0]

Expand Down Expand Up @@ -318,7 +318,7 @@ def energy(self, batch_ind=None, cubature=None, power=1.):

ep_energy = -(
lZ_post
+ 1 / power * (scale * np.nansum(lZ) - np.nansum(lZ_pseudo))
+ 1. / power * (scale * np.nansum(lZ) - np.nansum(lZ_pseudo))
)

return ep_energy
Expand Down

0 comments on commit 1d5b2dd

Please sign in to comment.