Skip to content

Commit

Permalink
use tqdm pbar during training
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 14, 2022
1 parent cf6db71 commit 6012825
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 24 deletions.
49 changes: 26 additions & 23 deletions denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,34 +603,37 @@ def load(self, milestone):
self.scaler.load_state_dict(data['scaler'])

def train(self):
while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()
with tqdm(initial = self.step, total = self.train_num_steps) as pbar:

with autocast(enabled = self.amp):
loss = self.model(data)
self.scaler.scale(loss / self.gradient_accumulate_every).backward()
while self.step < self.train_num_steps:
for i in range(self.gradient_accumulate_every):
data = next(self.dl).cuda()

print(f'{self.step}: {loss.item()}')
with autocast(enabled = self.amp):
loss = self.model(data)
self.scaler.scale(loss / self.gradient_accumulate_every).backward()

self.scaler.step(self.opt)
self.scaler.update()
self.opt.zero_grad()
pbar.set_description(f'loss: {loss.item():.4f}')

if self.step % self.update_ema_every == 0:
self.step_ema()
self.scaler.step(self.opt)
self.scaler.update()
self.opt.zero_grad()

if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema_model.eval()
if self.step % self.update_ema_every == 0:
self.step_ema()

milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = unnormalize_to_zero_to_one(all_images)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6)
self.save(milestone)
if self.step != 0 and self.step % self.save_and_sample_every == 0:
self.ema_model.eval()

self.step += 1
milestone = self.step // self.save_and_sample_every
batches = num_to_groups(36, self.batch_size)
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
all_images = torch.cat(all_images_list, dim=0)
all_images = unnormalize_to_zero_to_one(all_images)
utils.save_image(all_images, str(self.results_folder / f'sample-{milestone}.png'), nrow = 6)
self.save(milestone)

print('training completed')
self.step += 1
pbar.update(1)

print('training complete')
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'denoising-diffusion-pytorch',
packages = find_packages(),
version = '0.15.2',
version = '0.15.3',
license='MIT',
description = 'Denoising Diffusion Probabilistic Models - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 6012825

Please sign in to comment.