complete the gaussian diffusion with hybrid loss (learned variance) as in the improved ddpm paper
lucidrains committed May 12, 2022
1 parent d412d88 commit 62e8490
Showing 3 changed files with 64 additions and 21 deletions.
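For reference, the hybrid objective from the Improved DDPM paper (Nichol & Dhariwal, arXiv 2102.09672) that the commit message refers to adds a small weighted variational-bound term to the usual noise-prediction loss, with the reverse-process variance learned as an interpolation between beta_t and the clipped posterior variance. A sketch of the two formulas as stated in the paper:

```latex
% hybrid objective from Improved DDPM (arXiv 2102.09672); lambda = 0.001 as in the diff below
L_{\mathrm{hybrid}} = L_{\mathrm{simple}} + \lambda \, L_{\mathrm{vlb}}, \qquad \lambda = 0.001

% learned variance: the network predicts an interpolation fraction v in [0, 1]
\Sigma_\theta(x_t, t) = \exp\bigl( v \log \beta_t + (1 - v) \log \tilde{\beta}_t \bigr)
```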
29 changes: 20 additions & 9 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
@@ -40,6 +40,12 @@ def num_to_groups(num, divisor):
arr.append(remainder)
return arr

def normalize_to_neg_one_to_one(img):
return img * 2 - 1

def unnormalize_to_zero_to_one(t):
return (t + 1) * 0.5

# small helper modules

class EMA():
@@ -462,20 +468,23 @@ def q_sample(self, x_start, t, noise=None):
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)

@property
def loss_fn(self):
if self.loss_type == 'l1':
return F.l1_loss
elif self.loss_type == 'l2':
return F.mse_loss
else:
raise ValueError(f'invalid loss type {self.loss_type}')

def p_losses(self, x_start, t, noise = None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))

x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
x_recon = self.denoise_fn(x_noisy, t)

if self.loss_type == 'l1':
loss = (noise - x_recon).abs().mean()
elif self.loss_type == 'l2':
loss = F.mse_loss(noise, x_recon)
else:
raise NotImplementedError()

loss = self.loss_fn(noise, x_recon)
return loss

def forward(self, x, *args, **kwargs):
@@ -498,7 +507,7 @@ def __init__(self, folder, image_size, exts = ['jpg', 'jpeg', 'png']):
transforms.RandomHorizontalFlip(),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Lambda(lambda t: (t * 2) - 1)
transforms.Lambda(normalize_to_neg_one_to_one)
])

def __len__(self):
@@ -602,11 +611,13 @@ def train(self):
self.step_ema()

if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema_model.eval()

milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = (all_images + 1) * 0.5
all_images = unnormalize_to_zero_to_one(all_images)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6)
self.save(milestone)

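As a quick illustration of the two helpers added above (not part of the commit): they map images between [0, 1] and [-1, 1] and are exact inverses of each other.

```python
# standalone check of the two normalization helpers added in this commit;
# the definitions are copied from the diff above, the round-trip test is illustrative
import torch

def normalize_to_neg_one_to_one(img):
    return img * 2 - 1

def unnormalize_to_zero_to_one(t):
    return (t + 1) * 0.5

x = torch.rand(2, 3, 8, 8)           # pixel values in [0, 1], as produced by transforms.ToTensor()
y = normalize_to_neg_one_to_one(x)   # now in [-1, 1], the range the diffusion model sees
assert torch.allclose(unnormalize_to_zero_to_one(y), x)
```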
54 changes: 43 additions & 11 deletions denoising_diffusion_pytorch/learned_gaussian_diffusion.py
@@ -4,7 +4,7 @@
from torch import nn, einsum
from einops import rearrange

from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, extract, unnormalize_to_zero_to_one

# constants

@@ -58,17 +58,23 @@ def discretized_gaussian_log_likelihood(x, *, means, log_scales, thres = 0.999):

return log_probs

# gaussian diffusion for learned variance
# https://arxiv.org/abs/2102.09672

# i thought the results were questionable, if one were to focus only on FID
# but may as well get this in here for others to try, as GLIDE is using it (and DALL-E2 first stage of cascade)
# gaussian diffusion for learned variance + hybrid eps simple + vb loss

class LearnedGaussianDiffusion(GaussianDiffusion):
def __init__(
self,
denoise_fn,
vb_loss_weight = 0.001, # lambda was 0.001 in the paper
*args,
**kwargs
):
super().__init__(denoise_fn, *args, **kwargs)
assert denoise_fn.out_dim == (denoise_fn.channels * 2), 'dimension out of unet must be twice the number of channels for learned variance - you can also set the `learned_variance` keyword argument on the Unet to be `True`'
self.vb_loss_weight = vb_loss_weight

def q_posterior_mean_variance(self, x_start, x_t, t):
"""
@@ -89,25 +95,51 @@ def predict_xstart_from_xprev(self, x_t, t, xprev):
extract(self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape) * x_t
)

def p_mean_variance(self, *, x, t, clip_denoised):
model_output = self.denoise_fn(x, t)
model_output, model_log_variance = model_output.chunk(2, dim = 1)
def p_mean_variance(self, *, x, t, clip_denoised, model_output = None):
model_output = default(model_output, lambda: self.denoise_fn(x, t))
pred_noise, var_interp_frac_unnormalized = model_output.chunk(2, dim = 1)

min_log = extract(self.posterior_log_variance_clipped, t, x.shape)
max_log = extract(torch.log(self.betas), t, x.shape)
var_interp_frac = unnormalize_to_zero_to_one(var_interp_frac_unnormalized)

model_log_variance = var_interp_frac * max_log + (1 - var_interp_frac) * min_log
model_variance = model_log_variance.exp()
return model_output, model_variance, model_log_variance

x_start = self.predict_start_from_noise(x, t, pred_noise)
model_mean, _, _ = self.q_posterior(x_start, x, t)

return model_mean, model_variance, model_log_variance

def p_losses(self, x_start, t, noise = None, clip_denoised = False):
noise = default(noise, lambda: torch.randn_like(x_start))
x_t = self.q_sample(x_start = x_start, t = t, noise = noise)

# model output

model_output = self.denoise_fn(x_t, t)

# calculating kl loss for learned variance (interpolation)

true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start = x_start, x_t = x_t, t = t)
model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised)
model_mean, _, model_log_variance = self.p_mean_variance(x = x_t, t = t, clip_denoised = clip_denoised, model_output = model_output)

# kl loss with detached model predicted mean, for stability reasons as in paper

kl = normal_kl(true_mean, true_log_variance_clipped, model_mean, model_log_variance)
kl = normal_kl(true_mean, true_log_variance_clipped, model_mean.detach(), model_log_variance)
kl = meanflat(kl) * NAT

decoder_nll = -discretized_gaussian_log_likelihood(x_start, means = model_mean, log_scales = 0.5 * model_log_variance)
decoder_nll = meanflat(decoder_nll) * NAT

# At the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
losses = torch.where(t == 0, decoder_nll, kl)
return losses.mean()
# at the first timestep return the decoder NLL, otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))

vb_losses = torch.where(t == 0, decoder_nll, kl)

# simple loss - predicting noise, x0, or x_prev

pred_noise, _ = model_output.chunk(2, dim = 1)

simple_losses = self.loss_fn(pred_noise, noise)

return simple_losses + vb_losses.mean() * self.vb_loss_weight
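A minimal usage sketch for the new class (not part of the commit): the `learned_variance` flag on the Unet is the one named in the assertion above, while the remaining constructor arguments (`dim`, `dim_mults`, `image_size`, `timesteps`) are assumed to keep their existing names from the `Unet` / `GaussianDiffusion` API in this repository.

```python
# hypothetical usage sketch -- only denoise_fn, vb_loss_weight and learned_variance
# come from the diff; the other argument names are assumed from the existing API
import torch
from denoising_diffusion_pytorch.denoising_diffusion_pytorch import Unet
from denoising_diffusion_pytorch.learned_gaussian_diffusion import LearnedGaussianDiffusion

model = Unet(
    dim = 64,
    dim_mults = (1, 2, 4, 8),
    learned_variance = True        # doubles out_dim: (predicted noise, variance interpolation fraction)
)

diffusion = LearnedGaussianDiffusion(
    model,
    image_size = 128,
    timesteps = 1000,
    vb_loss_weight = 0.001         # lambda from the Improved DDPM paper
)

imgs = torch.rand(4, 3, 128, 128) * 2 - 1   # dummy batch already normalized to [-1, 1]
loss = diffusion(imgs)                      # simple loss + vb_loss_weight * variational bound
loss.backward()
```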
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.12.1',
version = '0.14.0',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
