Skip to content

Commit

Permalink
final cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 12, 2022
1 parent 55c658b commit caa5af1
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 14 deletions.
14 changes: 1 addition & 13 deletions denoising_diffusion_pytorch/learned_gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,6 @@ def __init__(
assert denoise_fn.out_dim == (denoise_fn.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
self.vb_loss_weight = vb_loss_weight

def q_posterior_mean_variance(self, x_start, x_t, t):
"""
Compute the mean and variance of the diffusion posterior q(x_{t-1} | x_t, x_0)
"""
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
model_output = default(model_output, lambda: self.denoise_fn(x, t))
pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)
Expand Down Expand Up @@ -118,7 +106,7 @@ def p_losses(self, x_start, t, noise = None, clip_denoised = False):

# calculating kl loss for learned variance (interpolation)

true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start = x_start, x_t = x_t, t = t)
true_mean, _, true_log_variance_clipped = self.q_posterior(x_start = x_start, x_t = x_t, t = t)
model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)

# kl loss with detached model predicted mean, for stability reasons as in paper
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.14.2',
version = '0.14.3',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit caa5af1

Please sign in to comment.