Skip to content
This repository has been archived by the owner on Oct 23, 2023. It is now read-only.

Commit

Permalink
undo changes
Browse files Browse the repository at this point in the history
fix


x


x


x


model


test


undo


dp


add todo


fix reinit_schedulers with correct optimizer (Lightning-AI#5519)

this is a cherry pick from master
boring model


lr


try


x


forward call
  • Loading branch information
awaelchli committed Jan 24, 2021
1 parent 48114fa commit 108da0c
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 78 deletions.
24 changes: 0 additions & 24 deletions pytorch_lightning/accelerators/dp_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,30 +155,6 @@ def test_step_end(self, output):
output = output.mean()
return output

def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
    """
    Reinitialize ``optimizer.step`` properties added by lr schedulers.

    Each scheduler is re-initialized against the optimizer it was created
    with, and its saved state is restored afterwards so training resumes
    with the correct learning-rate schedule.

    Args:
        optimizers: candidate optimizers; each scheduler is matched only to
            the optimizer it already references.
        schedulers: scheduler configuration dicts, each with a ``'scheduler'`` key.
    """
    for scheduler in schedulers:
        scheduler = scheduler['scheduler']

        for optimizer in optimizers:
            # check that we don't mix the user's optimizers and schedulers:
            # only re-init a scheduler with the optimizer it belongs to
            if scheduler.optimizer == optimizer:
                # find the mro entry belonging to the base lr scheduler class
                for klass in scheduler.__class__.__mro__:
                    # Compare the mro entry against the base classes directly.
                    # (The previous code bound the classes to locals and tested
                    # their truthiness, which is always True, so the last mro
                    # entry -- ``object`` -- was picked and re-init crashed.)
                    if klass in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
                        state = scheduler.state_dict()
                        klass.__init__(scheduler, optimizer)
                        scheduler.load_state_dict(state)
                        break
                # this scheduler is handled; never re-init it with another optimizer
                break

def get_reference_model(self, model) -> LightningModule:
if isinstance(model, LightningDataParallel):
return model.module
Expand Down
10 changes: 5 additions & 5 deletions pytorch_lightning/trainer/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,21 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
# Reinitialize optimizer.step properties added by schedulers
for scheduler in schedulers:
scheduler = scheduler['scheduler']
state = None

for optimizer in optimizers:
# check that we dont mix users optimizers and schedulers
if scheduler.optimizer == optimizer:
# Find the mro belonging to the base lr scheduler class
for i, mro in enumerate(scheduler.__class__.__mro__):
if mro in (optim.lr_scheduler._LRScheduler, optim.lr_scheduler.ReduceLROnPlateau):
idx = i
state = scheduler.state_dict()
else:
state = None
scheduler.__class__.__mro__[i].__init__(scheduler, optimizer)
scheduler.load_state_dict(state)
break

scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
if state is not None:
scheduler.load_state_dict(state)
break


class _MockOptimizer(Optimizer):
Expand Down
6 changes: 3 additions & 3 deletions tests/base/boring_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def step(self, x):
return out

def training_step(self, batch, batch_idx):
output = self.layer(batch)
output = self(batch)
loss = self.loss(batch, output)
return {"loss": loss}

Expand All @@ -104,15 +104,15 @@ def training_epoch_end(self, outputs) -> None:
torch.stack([x["loss"] for x in outputs]).mean()

def validation_step(self, batch, batch_idx):
output = self.layer(batch)
output = self(batch)
loss = self.loss(batch, output)
return {"x": loss}

def validation_epoch_end(self, outputs) -> None:
torch.stack([x['x'] for x in outputs]).mean()

def test_step(self, batch, batch_idx):
output = self.layer(batch)
output = self(batch)
loss = self.loss(batch, output)
return {"y": loss}

Expand Down
78 changes: 34 additions & 44 deletions tests/base/develop_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import DistributedType
from tests.base import BoringModel
from tests.base.develop_utils import get_default_logger, load_model_from_checkpoint, reset_seed


Expand Down Expand Up @@ -96,54 +95,45 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None,
trainer.checkpoint_connector.hpc_load(checkpoint_path, on_gpu=on_gpu)


def run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):
mode = trained_model.training
ref_model = trained_model.module if dp else trained_model
def run_prediction(trained_model, dataloader, min_acc=0.25):
return _boring_model_run_prediction(trained_model, dataloader, min_acc)

if isinstance(ref_model, BoringModel):
_boring_model_run_prediction(trained_model, dataloader, dp, min_acc)
else:
_eval_model_template_run_prediction(trained_model, dataloader, dp, min_acc)

trained_model.train(mode)


@torch.no_grad()
def _eval_model_template_run_prediction(trained_model, dataloader, dp=False, min_acc=0.50):
# run prediction on 1 batch
batch = next(iter(dataloader))
x, y = batch
x = x.view(x.size(0), -1)

if dp:
output = trained_model(batch, 0)
acc = output['val_acc']
acc = torch.mean(acc).item()

else:
y_hat = trained_model(x).cpu()

# acc
labels_hat = torch.argmax(y_hat, dim=1)

y = y.cpu()
acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
acc = torch.tensor(acc)
acc = acc.item()

assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})"
# def _eval_model_template_run_prediction(trained_model, dataloader, dp=False, min_acc=0.50):
# # run prediction on 1 batch
# batch = next(iter(dataloader))
# x, y = batch
# x = x.view(x.size(0), -1)
#
# if dp:
# with torch.no_grad():
# output = trained_model(batch, 0)
# acc = output['val_acc']
# acc = torch.mean(acc).item()
#
# else:
# with torch.no_grad():
# y_hat = trained_model(x)
# y_hat = y_hat.cpu()
#
# # acc
# labels_hat = torch.argmax(y_hat, dim=1)
#
# y = y.cpu()
# acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
# acc = torch.tensor(acc)
# acc = acc.item()
#
# assert acc >= min_acc, f"This model is expected to get > {min_acc} in test set (it got {acc})"


@torch.no_grad()
def _boring_model_run_prediction(trained_model, dataloader, dp=False, min_acc=0.25):
# TODO: This test compares a loss value with a min accuracy - complete non-sense!
# create BoringModels that make actual predictions!
def _boring_model_run_prediction(trained_model, dataloader, min_acc=0.25):
# run prediction on 1 batch
trained_model.cpu()
batch = next(iter(dataloader))

if dp:
acc = trained_model(batch, 0)['loss']
else:
with torch.no_grad():
output = trained_model(batch)
acc = trained_model.loss(batch, output)

acc = trained_model.loss(batch, output)

assert acc >= min_acc, f"This model is expected to get, {min_acc} in test set but got {acc}"
16 changes: 15 additions & 1 deletion tests/models/test_amp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import pytest
import torch
from torch import optim

import tests.base.develop_pipelines as tpipes
import tests.base.develop_utils as tutils
Expand Down Expand Up @@ -190,8 +191,18 @@ def test_amp_without_apex(tmpdir):
@pytest.mark.skipif(not _APEX_AVAILABLE, reason="test requires apex")
def test_amp_with_apex(tmpdir):
"""Check calling apex scaling in training."""
class CustomModel(BoringModel):
def configure_optimizers(self):
optimizer1 = optim.Adam(self.parameters(), lr=0.01)
optimizer2 = optim.SGD(self.parameters(), lr=0.01)
lr_scheduler1 = optim.lr_scheduler.StepLR(optimizer1, 1, gamma=0.1)
lr_scheduler2 = optim.lr_scheduler.StepLR(optimizer2, 1, gamma=0.1)
return [optimizer1, optimizer2], [lr_scheduler1, lr_scheduler2]

model = BoringModel()
def training_step(self, batch, batch_idx, optimizer_idx):
return super().training_step(batch, batch_idx)

model = CustomModel()

trainer = Trainer(
default_root_dir=tmpdir,
Expand All @@ -204,3 +215,6 @@ def test_amp_with_apex(tmpdir):
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
assert trainer.dev_debugger.count_events('AMP') == 10

assert isinstance(trainer.lr_schedulers[0]['scheduler'].optimizer, optim.Adam)
assert isinstance(trainer.lr_schedulers[1]['scheduler'].optimizer, optim.SGD)
2 changes: 1 addition & 1 deletion tests/models/test_restore.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def on_train_start(self):
assert self.trainer.current_epoch == real_global_epoch and self.trainer.current_epoch > 0

dataloader = self.train_dataloader()
tpipes.run_prediction(self.trainer.model, dataloader, dp=True)
tpipes.run_prediction(self.trainer.get_model(), dataloader)
self.on_train_start_called = True

# new model
Expand Down

0 comments on commit 108da0c

Please sign in to comment.