Skip to content

Commit

Permalink
Added default parameters to decay rate and steps as well as introduce…
Browse files Browse the repository at this point in the history
…d use lr decay flag
  • Loading branch information
andreasMazur committed Jul 31, 2024
1 parent 8d65119 commit f0bc9bc
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/geoconv_examples/mpi_faust/pytorch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,12 @@ def train_loop(self,
dataset,
loss_fn,
opt,
decay_rate,
decay_steps,
decay_rate=0.95,
decay_steps=500,
verbose=True,
epoch=None,
prev_steps=None):
prev_steps=None,
use_lr_decay=False):
self.train()
epoch_accuracy = 0.
epoch_loss = 0.
Expand All @@ -158,7 +159,8 @@ def train_loop(self,
loss.backward()
opt.step()

custom_exp_scheduler(opt, prev_steps + step, decay_rate=decay_rate, decay_steps=decay_steps)
if use_lr_decay:
custom_exp_scheduler(opt, prev_steps + step, decay_rate=decay_rate, decay_steps=decay_steps)

# Statistics
epoch_accuracy = epoch_accuracy + multiclass_accuracy(pred, gt).detach()
Expand Down

0 comments on commit f0bc9bc

Please sign in to comment.