Skip to content

Commit

Permalink
fix sampling ddpm tqdm
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Aug 10, 2022
1 parent 689593a commit f0d59ac
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
8 changes: 4 additions & 4 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def p_sample_loop(self, shape):

x_start = None

for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step'):
for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps):
self_cond = x_start if self.self_condition else None
img, x_start = self.p_sample(img, t, self_cond)

Expand Down Expand Up @@ -599,11 +599,11 @@ def interpolate(self, x1, x2, t = None, lam = 0.5):

assert x1.shape == x2.shape

t_batched = torch.stack([torch.tensor(t, device=device)] * b)
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
t_batched = torch.stack([torch.tensor(t, device = device)] * b)
xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2))

img = (1 - lam) * xt1 + lam * xt2
for i in tqdm(reversed(range(0, t)), desc='interpolation sample time step', total=t):
for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))

return img
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.27.0',
version = '0.27.1',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit f0d59ac

Please sign in to comment.