Skip to content

Commit

Permalink
successfully did some basic math and clipped the predicted x0 intermediate for the continuous time case
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 9, 2022
1 parent 4284c88 commit 582bfe2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
18 changes: 14 additions & 4 deletions denoising_diffusion_pytorch/continuous_time_gaussian_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(
loss_type = 'l1',
noise_schedule = 'linear',
num_sample_steps = 500,
clip_sample_denoised = True,
learned_schedule_net_hidden_dim = 1024,
learned_noise_schedule_frac_gradient = 1. # between 0 and 1, determines what percentage of gradients go back, so one can update the learned noise schedule more slowly
):
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__(
# sampling

self.num_sample_steps = num_sample_steps
self.clip_sample_denoised = clip_sample_denoised

@property
def device(self):
Expand All @@ -167,20 +169,28 @@ def p_mean_variance(self, x, time, time_next):
# reviewer found an error in the equation in the paper (missing sigma)
# following - https://openreview.net/forum?id=2LdBqxc1Yv&noteId=rIQgH0zKsRt

# todo - derive x_start from the posterior mean and do dynamic thresholding
# assumed that is what is going on in Imagen

log_snr = self.log_snr(time)
log_snr_next = self.log_snr(time_next)
c = -expm1(log_snr - log_snr_next)

squared_alpha, squared_alpha_next = log_snr.sigmoid(), log_snr_next.sigmoid()
squared_sigma, squared_sigma_next = (-log_snr).sigmoid(), (-log_snr_next).sigmoid()

alpha, sigma, alpha_next = map(sqrt, (squared_alpha, squared_sigma, squared_alpha_next))

batch_log_snr = repeat(log_snr, ' -> b', b = x.shape[0])
pred_noise = self.denoise_fn(x, batch_log_snr)

model_mean = sqrt(squared_alpha_next / squared_alpha) * (x - c * sqrt(squared_sigma) * pred_noise)
if self.clip_sample_denoised:
x_start = (x - sigma * pred_noise) / alpha

# in Imagen, this was changed to dynamic thresholding
x_start.clamp_(-1., 1.)

model_mean = alpha_next * (x * (1 - c) / alpha + c * x_start)
else:
model_mean = alpha_next / alpha * (x - c * sigma * pred_noise)

posterior_variance = squared_sigma_next * c

return model_mean, posterior_variance
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.17.4',
version = '0.17.6',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 582bfe2

Please sign in to comment.