Skip to content

Commit

Permalink
offer predict_x0 objective
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 13, 2022
1 parent caa5af1 commit 84ebb9a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
30 changes: 23 additions & 7 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,14 +339,16 @@ def __init__(
image_size,
channels = 3,
timesteps = 1000,
loss_type = 'l1'
loss_type = 'l1',
objective = 'pred_noise'
):
super().__init__()
assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)

self.channels = channels
self.image_size = image_size
self.denoise_fn = denoise_fn
self.objective = objective

betas = cosine_beta_schedule(timesteps)

Expand Down Expand Up @@ -404,12 +406,19 @@ def q_posterior(self, x_start, x_t, t):
return posterior_mean, posterior_variance, posterior_log_variance_clipped

def p_mean_variance(self, x, t, clip_denoised: bool):
    """Compute the posterior mean/variance of p(x_{t-1} | x_t) for sampling.

    Runs the denoising network once, converts its output to a predicted
    x_0 according to `self.objective`, optionally clamps that prediction
    to the valid image range, and plugs it into the closed-form posterior
    q(x_{t-1} | x_t, x_0).

    Args:
        x: noisy images x_t. Presumably shape (b, c, h, w) — see p_losses.
        t: timestep index tensor used by the network and the posterior.
        clip_denoised: when True, clamp the predicted x_0 into [-1, 1]
            in-place (images are assumed normalized to that range).

    Returns:
        (model_mean, posterior_variance, posterior_log_variance) from
        `self.q_posterior`.

    Raises:
        ValueError: if `self.objective` is neither 'pred_noise' nor 'pred_x0'.
    """
    model_output = self.denoise_fn(x, t)

    # Interpret the network output depending on the training objective:
    # either it predicted the noise (invert q_sample to recover x_0),
    # or it predicted x_0 directly.
    if self.objective == 'pred_noise':
        x_start = self.predict_start_from_noise(x, t = t, noise = model_output)
    elif self.objective == 'pred_x0':
        x_start = model_output
    else:
        raise ValueError(f'unknown objective {self.objective}')

    if clip_denoised:
        x_start.clamp_(-1., 1.)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
    return model_mean, posterior_variance, posterior_log_variance

@torch.no_grad()
def p_losses(self, x_start, t, noise = None):
    """Compute the training loss for one batch of images at timesteps t.

    Forward-diffuses `x_start` to x_t via `q_sample`, runs the denoising
    network, and compares its output against the target dictated by
    `self.objective` using `self.loss_fn`.

    Args:
        x_start: clean images of shape (b, c, h, w).
        t: per-sample timestep indices.
        noise: optional pre-drawn Gaussian noise; defaults to
            `torch.randn_like(x_start)`.

    Returns:
        Scalar loss from `self.loss_fn(model_out, target)`.

    Raises:
        ValueError: if `self.objective` is neither 'pred_noise' nor 'pred_x0'.
    """
    b, c, h, w = x_start.shape
    noise = default(noise, lambda: torch.randn_like(x_start))

    x = self.q_sample(x_start = x_start, t = t, noise = noise)
    model_out = self.denoise_fn(x, t)

    # The regression target mirrors p_mean_variance: the network learns
    # either the injected noise or the clean image itself.
    if self.objective == 'pred_noise':
        target = noise
    elif self.objective == 'pred_x0':
        target = x_start
    else:
        raise ValueError(f'unknown objective {self.objective}')

    loss = self.loss_fn(model_out, target)
    return loss

def forward(self, x, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.14.3',
version = '0.15.0',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 84ebb9a

Please sign in to comment.