Skip to content

Commit

Permalink
fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
xh-liu committed Dec 20, 2020
1 parent 2400c86 commit 27801f0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
1 change: 0 additions & 1 deletion models/OpenEdit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, opt):

self.generator = opt.netG

self.noise_range = opt.noise_range
if self.perturbation:
self.netP = PerturbationNet(opt)
self.netP.cuda()
Expand Down
32 changes: 32 additions & 0 deletions trainers/OpenEdit_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,35 @@ def run_discriminator_one_step(self, data):
d_loss.backward()
self.optimizer_D.step()
self.d_losses = d_losses

def get_latest_losses(self):
    """Return generator and discriminator losses merged into one dict."""
    losses = dict(self.g_losses)
    losses.update(self.d_losses)
    return losses

def get_latest_generated(self):
    """Return the most recent generator output cached on the trainer."""
    return self.generated

def save(self, epoch):
    """Checkpoint the wrapped model, tagging the files with *epoch*.

    Delegates to ``self.open_edit_model_on_one_gpu.save`` so saving
    always goes through the unwrapped (non-DataParallel) model.
    """
    self.open_edit_model_on_one_gpu.save(epoch)

def update_learning_rate(self, epoch):
    """Linearly decay the learning rate once *epoch* exceeds ``opt.niter``.

    The base rate is held at ``self.old_lr`` for the first ``opt.niter``
    epochs, then reduced by ``opt.lr / opt.niter_decay`` per call.  With
    TTUR enabled (``opt.no_TTUR`` false) the discriminator runs at twice
    and the generator at half the base rate.
    """
    # Constant schedule until `niter`, linear decay afterwards.
    step = self.opt.lr / self.opt.niter_decay
    new_lr = self.old_lr - step if epoch > self.opt.niter else self.old_lr

    if new_lr == self.old_lr:
        return  # still in the constant phase; nothing to update

    if self.opt.no_TTUR:
        new_lr_G = new_lr_D = new_lr
    else:
        # Two time-scale update rule: slower generator, faster discriminator.
        new_lr_G = new_lr / 2
        new_lr_D = new_lr * 2

    # NOTE(review): guard scope reconstructed — discriminator optimizer only
    # exists when adversarial loss is enabled; confirm against the repo.
    if self.loss_gan:
        for group in self.optimizer_D.param_groups:
            group['lr'] = new_lr_D
    for group in self.optimizer_G.param_groups:
        group['lr'] = new_lr_G
    print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
    self.old_lr = new_lr

0 comments on commit 27801f0

Please sign in to comment.