Make beta tc loss more stable using torch.logsumexp (from: https://gi…
meffmadd committed Nov 11, 2022
1 parent 28d0e95 commit 160c003
Showing 1 changed file with 3 additions and 3 deletions.
disent/frameworks/vae/_unsupervised__betatcvae.py (3 additions & 3 deletions)
@@ -93,7 +93,7 @@ def _betatc_compute_loss(self, d_posterior: Normal, z_sampled):
     def _betatc_compute_total_correlation(z_sampled, z_mean, z_logvar):
         """
         Estimate total correlation over a batch.
-        Reference implementation is from: https://github.com/amir-abdi/disentanglement-pytorch
+        Reference implementation is from: https://github.com/YannDubs/disentangling-vae
         """
         # Compute log(q(z(x_j)|x_i)) for every sample in the batch, which is a
         # tensor of size [batch_size, batch_size, num_latents]. In the following
@@ -103,11 +103,11 @@ def _betatc_compute_total_correlation(z_sampled, z_mean, z_logvar):
         # Compute log prod_l p(z(x_j)_l) = sum_l(log(sum_i(q(z(z_j)_l|x_i)))
         # + constant) for each sample in the batch, which is a vector of size
         # [batch_size,].
-        log_qz_product = log_qz_prob.exp().sum(dim=1, keepdim=False).log().sum(dim=1, keepdim=False)
+        log_qz_product = torch.logsumexp(log_qz_prob, dim=1, keepdim=False).sum(dim=1, keepdim=False)

         # Compute log(q(z(x_j))) as log(sum_i(q(z(x_j)|x_i))) + constant =
         # log(sum_i(prod_l q(z(x_j)_l|x_i))) + constant.
-        log_qz = log_qz_prob.sum(dim=2, keepdim=False).exp().sum(dim=1, keepdim=False).log()
+        log_qz = torch.logsumexp(log_qz_prob.sum(dim=2), dim=1, keepdim=False)

         return (log_qz - log_qz_product).mean()

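For context, a minimal standalone sketch of why this change matters numerically (this is not the disent code; the shapes and values below are synthetic stand-ins): chaining .exp().sum().log() underflows when the pairwise log-densities in log_qz_prob are strongly negative, producing -inf, whereas torch.logsumexp subtracts the per-slice maximum before exponentiating and stays finite.

import torch

# Synthetic stand-in for log(q(z(x_j)|x_i)) with shape [batch_size, batch_size, num_latents];
# large negative log-densities like these occur when samples are far apart in latent space.
log_qz_prob = torch.randn(8, 8, 6) - 800.0

# Old formulation: exp(-800) underflows to 0.0 in float32, so the log() yields -inf.
log_qz_product_old = log_qz_prob.exp().sum(dim=1, keepdim=False).log().sum(dim=1, keepdim=False)

# New formulation: mathematically identical, but logsumexp factors out the maximum
# before exponentiating, so the result stays finite.
log_qz_product_new = torch.logsumexp(log_qz_prob, dim=1, keepdim=False).sum(dim=1, keepdim=False)

print(log_qz_product_old)  # tensor([-inf, -inf, ..., -inf])
print(log_qz_product_new)  # finite values (roughly -800 per latent, summed over the 6 latents)

The same identity, logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), justifies the log_qz rewrite in the second hunk.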
