Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
fix reinit_schedulers with correct optimizer (Lightning-AI#5519)
Browse files Browse the repository at this point in the history
this is a cherry pick from master
  • Loading branch information
awaelchli committed Jan 24, 2021
1 parent 5837c45 commit 3090204
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 29 deletions.
24 changes: 0 additions & 24 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,30 +155,6 @@ def test_step_end(self, output):
output = output.mean()
return output

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
"""
Reinitialize optimizer.step properties added by schedulers
"""
for scheduler in schedulers:
scheduler = scheduler['scheduler']

for optimizer in optimizers:
# check that we dont mix users optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
is_regular_scheduler = optim.lr_scheduler._LRScheduler
is_lr_reduce_on_plateau = optim.lr_scheduler.ReduceLROnPlateau
if is_regular_scheduler or is_lr_reduce_on_plateau:
idx = i
state = scheduler.state_dict()
else:
state = None

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)

def get_reference_model(self, model) -> LightningModule:
if isinstance(model, LightningDataParallel):
return model.module
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,21 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
scheduler = scheduler['scheduler']
state = None

for optimizer in optimizers:
# check that we dont mix users optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
state = None
scheduler.__class__.__mro__[i].__init__(scheduler, optimizer)
scheduler.load_state_dict(state)
break

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)
break


class _MockOptimizer(Optimizer):
Expand Down
11 changes: 11 additions & 0 deletions tests/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
import torch
from torch import optim

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
Expand Down Expand Up @@ -190,6 +191,13 @@ def test_amp_without_apex(tmpdir):
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""
class CustomModel(EvalModelTemplate):
def configure_optimizers(self):
optimizer1 = optim.Adam(self.parameters(), lr=self.learning_rate)
optimizer2 = optim.SGD(self.parameters(), lr=self.learning_rate)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]

model = BoringModel()

Expand All @@ -204,3 +212,6 @@ def test_amp_with_apex(tmpdir):
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert trainer.dev_debugger.count_events('AMP') == 10

assert isinstance(trainer.lr_schedulers[0]['scheduler'].optimizer, optim.Adam)
assert isinstance(trainer.lr_schedulers[1]['scheduler'].optimizer, optim.SGD)

0 comments on commit 3090204

Please sign in to comment.