From c10e21144582ad9e845791e28e0d991a0d29ed33 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 27 Apr 2021 17:38:27 +0100 Subject: [PATCH 01/26] update --- .../connectors/checkpoint_connector.py | 2 + tests/callbacks/test_finetuning_callback.py | 53 ++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 4ae42e4bad6ac..536cc489c2d46 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -179,6 +179,8 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True) # restore the optimizers optimizer_states = checkpoint['optimizer_states'] + import pdb + pdb.set_trace() for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index c11d58cb18543..42b9b8c05d085 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -20,7 +20,7 @@ from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, seed_everything, Trainer -from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning +from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint from pytorch_lightning.callbacks.base import Callback from tests.helpers import BoringModel, RandomDataset @@ -123,6 +123,8 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): pl_module.backbone, optimizer, 0.1, train_bn=self.train_bn, initial_denom_lr=self.initial_denom_lr ) + chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) + model = FinetuningBoringModel() model.validation_step = None callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False) @@ -131,12 +133,14 @@ def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): trainer = Trainer( limit_train_batches=1, default_root_dir=tmpdir, - callbacks=[callback], + callbacks=[callback, chk], max_epochs=2, ) trainer.fit(model) assert model.backbone.has_been_used + trainer = Trainer(max_epochs=3, resume_from_checkpoint=chk.last_model_path) + trainer.fit(model) def test_freeze_unfreeze_function(tmpdir): @@ -283,3 +287,48 @@ def forward(self, x): # conv0.weight, conv0.bias, bn0.weight, bn0.bias # conv1.weight, conv1.bias, bn1.weight, bn1.bias assert len(encoder_params) == 8 + + +def test_bolts(tmpdir): + + from pl_bolts.datamodules import SklearnDataModule + from pl_bolts.models.regression import LinearRegression + from sklearn.datasets import load_boston + + import pytorch_lightning as pl + from pytorch_lightning.callbacks import ModelCheckpoint + from pytorch_lightning.callbacks.finetuning import BaseFinetuning + + X, y = load_boston(return_X_y=True) + dm = SklearnDataModule(X, y) + + class TmpLinearRegression(LinearRegression): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.to_freeze = nn.Linear(1, 1) + + def configure_optimizers(self): + trainable_parameters = list(filter(lambda p: p.requires_grad, self.parameters())) + return self.optimizer(trainable_parameters, lr=self.hparams.learning_rate) + + class LRFinetuner(BaseFinetuning): + + def freeze_before_training(self, pl_module): + self.freeze(modules=pl_module.to_freeze) + + def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): + if current_epoch == 1: + self.unfreeze_and_add_param_group( + modules=pl_module.to_freeze, + optimizer=optimizer, + initial_denom_lr=10, + ) + + model = TmpLinearRegression(input_dim=13) + trainer = pl.Trainer(max_epochs=2, callbacks=[ModelCheckpoint(dirpath="./tmp", save_last=True), LRFinetuner()]) + trainer.fit(model, train_dataloader=dm.train_dataloader(), val_dataloaders=dm.val_dataloader()) + + model = TmpLinearRegression(input_dim=13) + trainer = pl.Trainer(max_epochs=3, resume_from_checkpoint="tmp/last.ckpt") + trainer.fit(model, train_dataloader=dm.train_dataloader(), val_dataloaders=dm.val_dataloader()) From 82373f50f5df2503a290c5b233f72b7e7b95d870 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 28 Apr 2021 12:00:55 +0100 Subject: [PATCH 02/26] wip --- pytorch_lightning/callbacks/finetuning.py | 42 ++++++++++++++-- tests/callbacks/test_finetuning_callback.py | 55 ++++++++++++++++++++- 2 files changed, 92 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index ea508775d126f..34761db36e768 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -17,13 +17,14 @@ Freeze and unfreeze models for finetuning purposes """ import logging -from typing import Callable, Generator, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union import torch from torch.nn import Module from torch.nn.modules.batchnorm import _BatchNorm from torch.optim.optimizer import Optimizer +import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn @@ -68,7 +69,7 @@ def __init__(self, unfreeze_at_epoch=10) self._unfreeze_at_epoch = unfreeze_at_epoch def freeze_before_training(self, pl_module): - # freeze any module you want + # freeze any module you want # Here, we are freezing ``feature_extractor`` self.freeze(pl_module.feature_extractor) @@ -82,6 +83,18 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): ) """ + def __init__(self): + self._internal_state = {} + + def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: LightningModule, + checkpoint: Dict[str, Any]) -> Dict[str, Any]: + return self._internal_state + + def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + self._internal_state = callback_state + import pdb + pdb.set_trace() + @staticmethod def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: """ @@ -234,10 +247,33 @@ def unfreeze_and_add_param_group( def on_before_accelerator_backend_setup(self, trainer, pl_module): self.freeze_before_training(pl_module) + def _add_to_internal_state( + self, pl_module: LightningModule, opt_idx: int, current_param_groups: List[Dict[str, Any]] + ) -> None: + map_p_to_name = {p: n for n, p in pl_module.named_parameters()} + for g in current_param_groups: + group_state = {k: v for k, v in g.items() if k != 'params'} + group_state['params'] = [map_p_to_name[p] for p in g['params']] + self._internal_state[opt_idx].append(group_state) + + def _store( + self, pl_module: LightningModule, opt_idx: int, len_previous_param_groups: List[Dict[str, Any]], + current_param_groups: List[Dict[str, Any]] + ) -> None: + if opt_idx not in self._internal_state: + self._internal_state[opt_idx] = [] + self._add_to_internal_state(pl_module, opt_idx, current_param_groups) + + elif len_previous_param_groups != len(current_param_groups): + self._add_to_internal_state(pl_module, opt_idx, current_param_groups[len_previous_param_groups:]) + def on_train_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): + len_previous_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) + current_param_groups = optimizer.param_groups + self._store(pl_module, opt_idx, len_previous_param_groups, current_param_groups) def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): """ @@ -305,6 +341,7 @@ def __init__( verbose: bool = False, round: int = 12, ): + super().__init__() self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch self.backbone_initial_lr = backbone_initial_lr self.lambda_func = lambda_func @@ -330,7 +367,6 @@ def freeze_before_training(self, pl_module: LightningModule): def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): """Called when the epoch begins.""" - if epoch == self.unfreeze_backbone_at_epoch: current_lr = optimizer.param_groups[0]['lr'] initial_backbone_lr = self.backbone_initial_lr if self.backbone_initial_lr is not None \ diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 42b9b8c05d085..315487d81ba90 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -16,7 +16,7 @@ import pytest import torch from torch import nn -from torch.optim import SGD +from torch.optim import Optimizer, SGD from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, seed_everything, Trainer @@ -73,7 +73,7 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False) trainer = Trainer( - limit_train_batches=1, + limit_train_batches=4, default_root_dir=tmpdir, callbacks=[callback], max_epochs=8, @@ -224,6 +224,57 @@ def __init__(self): assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight) +class OnEpochLayerFinetuning(BaseFinetuning): + + def freeze_before_training(self, pl_module: LightningModule): + self.freeze(pl_module.layer) + + def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + self.unfreeze_and_add_param_group(pl_module.layer[epoch + 1], optimizer) + + +def test_base_finetuning_internal_state(tmpdir): + """Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks""" + + seed_everything(42) + + class FreezeModel(BoringModel): + + def __init__(self): + super().__init__() + self.layer = nn.Sequential( + nn.Linear(32, 32, bias=False), + nn.Linear(32, 32, bias=True), + nn.Linear(32, 32, bias=False), + nn.Linear(32, 32, bias=True), + nn.Linear(32, 32, bias=False), + nn.Linear(32, 2, bias=True), + ) + + def forward(self, x): + return self.layer(x) + + def configure_optimizers(self): + return torch.optim.SGD(self.layer[0].parameters(), lr=0.1) + + cb = OnEpochLayerFinetuning() + chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) + model = FreezeModel() + trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=1, callbacks=[cb, chk]) + trainer.fit(model) + assert len(cb._internal_state[0]) == 6 + assert cb._internal_state[0][0]["params"] == ['layer.0.weight'] + assert cb._internal_state[0][1]["params"] == ['layer.1.weight', 'layer.1.bias'] + assert cb._internal_state[0][2]["params"] == ['layer.2.weight'] + assert cb._internal_state[0][3]["params"] == ['layer.3.weight', 'layer.3.bias'] + assert cb._internal_state[0][4]["params"] == ['layer.4.weight'] + assert cb._internal_state[0][5]["params"] == ['layer.5.weight', 'layer.5.bias'] + + model = FreezeModel() + trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path) + trainer.fit(model) + + def test_on_before_accelerator_backend_setup(tmpdir): """ `on_before_accelerator_backend_setup` hook is used by finetuning callbacks to freeze the model before From ab9b2ffb9d514947b4bd32d8d7d86a8a4913af9b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 28 Apr 2021 14:22:28 +0100 Subject: [PATCH 03/26] udpate --- pytorch_lightning/callbacks/finetuning.py | 20 +++++++++++++++--- pytorch_lightning/trainer/callback_hook.py | 21 ++++++++++++++++--- .../connectors/checkpoint_connector.py | 2 -- tests/callbacks/test_finetuning_callback.py | 6 ++++-- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 34761db36e768..3ef010d15724f 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -17,6 +17,7 @@ Freeze and unfreeze models for finetuning purposes """ import logging +from copy import deepcopy from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union import torch @@ -90,10 +91,23 @@ def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: LightningModule, checkpoint: Dict[str, Any]) -> Dict[str, Any]: return self._internal_state - def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + def on_load_checkpoint( + self, trainer: 'pl.Trainer', pl_module: LightningModule, callback_state: Dict[str, Any] + ) -> None: self._internal_state = callback_state - import pdb - pdb.set_trace() + # restore the param_groups created during training. + map_name_to_p = {n: p for n, p in pl_module.named_parameters()} + for opt_idx, optimizer in enumerate(trainer.optimizers): + param_groups = self._param_groups_state_to_param_groups( + deepcopy(self._internal_state[opt_idx]), map_name_to_p + ) + optimizer.param_groups = param_groups + + def _param_groups_state_to_param_groups(self, param_groups_state: Dict[str, Any], + map_name_to_p: Dict[str, ]) -> Dict[str, Any]: + for group in param_groups_state: + group["params"] = [map_name_to_p[name] for name in group["params"]] + return param_groups_state @staticmethod def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 4473bec5f026f..5ce53816b97ce 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -268,17 +268,24 @@ def on_keyboard_interrupt(self): callback.on_keyboard_interrupt(self, self.lightning_module) @staticmethod - def __is_old_signature(fn: Callable) -> bool: + def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) if len(parameters) == 2 and parameters[1] != "args": return True return False + @staticmethod + def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool: + parameters = list(signature(fn).parameters) + if len(parameters) == 1: + return True + return False + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: """Called when saving a model checkpoint.""" callback_states = {} for callback in self.callbacks: - if self.__is_old_signature(callback.on_save_checkpoint): + if self.__is_old_signature_on_save_checkpoint(callback.on_save_checkpoint): rank_zero_deprecation( "`Callback.on_save_checkpoint` signature has changed in v1.3." " A `checkpoint` parameter has been added." @@ -302,7 +309,15 @@ def on_load_checkpoint(self, checkpoint): state = callback_states.get(type(callback)) if state: state = deepcopy(state) - callback.on_load_checkpoint(state) + if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): + rank_zero_deprecation( + "`Callback.on_load_checkpoint` signature has changed in v1.3." + " `Trainer` and `LightningModule` parameter are been added." + " Support for the old signature will be removed in v1.5" + ) + state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled + else: + state = callback.on_load_checkpoint(self, self.lightning_module, state) def on_after_backward(self): """ diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 536cc489c2d46..4ae42e4bad6ac 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -179,8 +179,6 @@ def restore_training_state(self, checkpoint, load_optimizer_states: bool = True) # restore the optimizers optimizer_states = checkpoint['optimizer_states'] - import pdb - pdb.set_trace() for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states): optimizer.load_state_dict(opt_state) diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 315487d81ba90..4fe258e09d6d4 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -271,8 +271,10 @@ def configure_optimizers(self): assert cb._internal_state[0][5]["params"] == ['layer.5.weight', 'layer.5.bias'] model = FreezeModel() - trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path) - trainer.fit(model) + cb = OnEpochLayerFinetuning() + trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb]) + with pytest.raises(IndexError, match="index 6 is out of range"): + trainer.fit(model) def test_on_before_accelerator_backend_setup(tmpdir): From 4682cdec2bdc7bd37680fbc06f15734fa0932c2a Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 28 Apr 2021 14:23:26 +0100 Subject: [PATCH 04/26] update --- pytorch_lightning/callbacks/finetuning.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 3ef010d15724f..d434b77104741 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -96,15 +96,16 @@ def on_load_checkpoint( ) -> None: self._internal_state = callback_state # restore the param_groups created during training. - map_name_to_p = {n: p for n, p in pl_module.named_parameters()} + _map_name_to_p = {n: p for n, p in pl_module.named_parameters()} for opt_idx, optimizer in enumerate(trainer.optimizers): param_groups = self._param_groups_state_to_param_groups( - deepcopy(self._internal_state[opt_idx]), map_name_to_p + deepcopy(self._internal_state[opt_idx]), _map_name_to_p ) optimizer.param_groups = param_groups - def _param_groups_state_to_param_groups(self, param_groups_state: Dict[str, Any], - map_name_to_p: Dict[str, ]) -> Dict[str, Any]: + def _param_groups_state_to_param_groups( + self, param_groups_state: Dict[str, Any], map_name_to_p: Dict[str, torch.Tensor] + ) -> Dict[str, Any]: for group in param_groups_state: group["params"] = [map_name_to_p[name] for name in group["params"]] return param_groups_state From 400a591b4ac1685c6cd3e417616e8831250a735f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 28 Apr 2021 14:28:48 +0100 Subject: [PATCH 05/26] update --- CHANGELOG.md | 3 ++ tests/callbacks/test_finetuning_callback.py | 45 --------------------- 2 files changed, 3 insertions(+), 45 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 427b5ad9fd3d7..c9aa76cbb7126 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -364,6 +364,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228)) +- Fixed `BaseFinetuning` properly reloading `optimizer_states` when using `resume_from_checkpoint` ([#6891](https://github.com/PyTorchLightning/pytorch-lightning/pull/6891)) + + ## [1.2.7] - 2021-04-06 ### Fixed diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 4fe258e09d6d4..265d91c2faca7 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -340,48 +340,3 @@ def forward(self, x): # conv0.weight, conv0.bias, bn0.weight, bn0.bias # conv1.weight, conv1.bias, bn1.weight, bn1.bias assert len(encoder_params) == 8 - - -def test_bolts(tmpdir): - - from pl_bolts.datamodules import SklearnDataModule - from pl_bolts.models.regression import LinearRegression - from sklearn.datasets import load_boston - - import pytorch_lightning as pl - from pytorch_lightning.callbacks import ModelCheckpoint - from pytorch_lightning.callbacks.finetuning import BaseFinetuning - - X, y = load_boston(return_X_y=True) - dm = SklearnDataModule(X, y) - - class TmpLinearRegression(LinearRegression): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.to_freeze = nn.Linear(1, 1) - - def configure_optimizers(self): - trainable_parameters = list(filter(lambda p: p.requires_grad, self.parameters())) - return self.optimizer(trainable_parameters, lr=self.hparams.learning_rate) - - class LRFinetuner(BaseFinetuning): - - def freeze_before_training(self, pl_module): - self.freeze(modules=pl_module.to_freeze) - - def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): - if current_epoch == 1: - self.unfreeze_and_add_param_group( - modules=pl_module.to_freeze, - optimizer=optimizer, - initial_denom_lr=10, - ) - - model = TmpLinearRegression(input_dim=13) - trainer = pl.Trainer(max_epochs=2, callbacks=[ModelCheckpoint(dirpath="./tmp", save_last=True), LRFinetuner()]) - trainer.fit(model, train_dataloader=dm.train_dataloader(), val_dataloaders=dm.val_dataloader()) - - model = TmpLinearRegression(input_dim=13) - trainer = pl.Trainer(max_epochs=3, resume_from_checkpoint="tmp/last.ckpt") - trainer.fit(model, train_dataloader=dm.train_dataloader(), val_dataloaders=dm.val_dataloader()) From 20ac64dc71b6aaf7feb408e930a28b93d066a720 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 28 Apr 2021 14:50:21 +0100 Subject: [PATCH 06/26] update --- pytorch_lightning/callbacks/base.py | 6 +- pytorch_lightning/trainer/callback_hook.py | 24 ++++--- tests/deprecated_api/test_remove_1-5.py | 75 ++++++++++++++++++++++ 3 files changed, 91 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a7a65bd8cac55..d4cad4a3b1234 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -273,10 +273,14 @@ def on_save_checkpoint( """ pass - def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None: + def on_load_checkpoint( + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', callback_state: Dict[str, Any] + ) -> None: """Called when loading a model checkpoint, use to reload state. Args: + trainer: the current Trainer instance. + pl_module: the current 'pl.LightningModule' instance. callback_state: the callback state returned by ``on_save_checkpoint``. """ pass diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 5ce53816b97ce..ac6d446706bc2 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -277,7 +277,7 @@ def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool: @staticmethod def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) - if len(parameters) == 1: + if len(parameters) == 1 and parameters[0] == "callback_state": return True return False @@ -306,18 +306,16 @@ def on_load_checkpoint(self, checkpoint): # https://github.com/pytorch/xla/issues/2773 if callback_states is not None: for callback in self.callbacks: - state = callback_states.get(type(callback)) - if state: - state = deepcopy(state) - if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): - rank_zero_deprecation( - "`Callback.on_load_checkpoint` signature has changed in v1.3." - " `Trainer` and `LightningModule` parameter are been added." - " Support for the old signature will be removed in v1.5" - ) - state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled - else: - state = callback.on_load_checkpoint(self, self.lightning_module, state) + state = deepcopy(callback_states.get(type(callback))) + if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): + rank_zero_deprecation( + "`Callback.on_load_checkpoint` signature has changed in v1.3." + " `Trainer` and `LightningModule` parameter are been added." + " Support for the old signature will be removed in v1.5" + ) + state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled + else: + state = callback.on_load_checkpoint(self, self.lightning_module, state) def on_after_backward(self): """ diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 6516fbcc18639..530d1f44df1e2 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -93,6 +93,81 @@ def on_save_checkpoint(self, *args): trainer.save_checkpoint(filepath) +class BaseSignatureOnLoadCheckpoint(Callback): + + def __init__(self): + self.on_load_checkpoint_called = False + + +class NewSignatureOnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): + + def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict: + return {"something": "something"} + + def on_load_checkpoint(self, trainer, pl_module, checkpoint): + assert checkpoint == {"something": "something"} + self.on_load_checkpoint_called = True + + +def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir): + + class OldSignature(BaseSignatureOnLoadCheckpoint): + + def on_load_checkpoint(self, callback_state) -> None: + assert callback_state is None + self.on_load_checkpoint_called = True + + model = BoringModel() + trainer_kwargs = { + "default_root_dir": tmpdir, + "max_epochs": 3, + } + chk = ModelCheckpoint(save_last=True) + trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature(), chk]) + trainer.fit(model) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer_kwargs["max_epochs"] = 5 + cb = OldSignature() + trainer = Trainer(**trainer_kwargs, callbacks=cb, resume_from_checkpoint=chk.last_model_path) + trainer.fit(model) + assert cb.on_load_checkpoint_called + + class ValidSignature1(BaseSignatureOnLoadCheckpoint): + + def on_load_checkpoint(self, trainer, *args): + assert len(args) == 2 + self.on_load_checkpoint_called = True + + class ValidSignature2(BaseSignatureOnLoadCheckpoint): + + def on_load_checkpoint(self, *args): + assert len(args) == 3 + self.on_load_checkpoint_called = True + + model = BoringModel() + trainer_kwargs = { + "default_root_dir": tmpdir, + "max_epochs": 3, + } + chk = ModelCheckpoint(save_last=True) + trainer = Trainer( + **trainer_kwargs, callbacks=[NewSignatureOnLoadCheckpoint(), + ValidSignature1(), + ValidSignature2(), chk] + ) + trainer.fit(model) + + with pytest.deprecated_call(match="old signature will be removed in v1.5"): + trainer_kwargs["max_epochs"] = 5 + cb, cb_1, cb_2 = NewSignatureOnLoadCheckpoint(), ValidSignature1(), ValidSignature2() + trainer = Trainer(**trainer_kwargs, callbacks=[cb, cb_1, cb_2], resume_from_checkpoint=chk.last_model_path) + trainer.fit(model) + assert cb.on_load_checkpoint_called + assert cb_1.on_load_checkpoint_called + assert cb_2.on_load_checkpoint_called + + def test_v1_5_0_legacy_profiler_argument(): with pytest.deprecated_call(match="renamed to `record_functions` in v1.3"): PyTorchProfiler(profiled_functions=[]) From 92c55cf77dd0b75bbf9ef30674d38314ff93889f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 28 Apr 2021 16:31:57 +0100 Subject: [PATCH 07/26] resolve bug --- tests/callbacks/test_finetuning_callback.py | 52 +++++++++++---------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 265d91c2faca7..971172577e421 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -25,6 +25,20 @@ from tests.helpers import BoringModel, RandomDataset +class TestTestBackboneFinetuningCallbackCallback(BackboneFinetuning): + + def on_train_epoch_end(self, trainer, pl_module, outputs): + epoch = trainer.current_epoch + if self.unfreeze_backbone_at_epoch <= epoch: + optimizer = trainer.optimizers[0] + current_lr = optimizer.param_groups[0]['lr'] + backbone_lr = self.previous_backbone_lr + if epoch < 6: + assert backbone_lr <= current_lr + else: + assert backbone_lr == current_lr + + def test_finetuning_callback(tmpdir): """Test finetuning callbacks works as expected""" @@ -56,21 +70,8 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(RandomDataset(32, 64), batch_size=2) - class TestCallback(BackboneFinetuning): - - def on_train_epoch_end(self, trainer, pl_module, outputs): - epoch = trainer.current_epoch - if self.unfreeze_backbone_at_epoch <= epoch: - optimizer = trainer.optimizers[0] - current_lr = optimizer.param_groups[0]['lr'] - backbone_lr = self.previous_backbone_lr - if epoch < 6: - assert backbone_lr <= current_lr - else: - assert backbone_lr == current_lr - model = FinetuningBoringModel() - callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False) + callback = TestTestBackboneFinetuningCallbackCallback(unfreeze_backbone_at_epoch=3, verbose=False) trainer = Trainer( limit_train_batches=4, @@ -83,6 +84,17 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): assert model.backbone.has_been_used +class TestBackboneFinetuningWarningCallback(BackboneFinetuning): + + def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): + """Called when the epoch begins.""" + + if epoch == 0: + self.unfreeze_and_add_param_group( + pl_module.backbone, optimizer, 0.1, train_bn=self.train_bn, initial_denom_lr=self.initial_denom_lr + ) + + def test_finetuning_callback_warning(tmpdir): """Test finetuning callbacks works as expected""" @@ -113,21 +125,11 @@ def configure_optimizers(self): optimizer = torch.optim.SGD(self.parameters(), lr=0.1) return optimizer - class TestCallback(BackboneFinetuning): - - def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int): - """Called when the epoch begins.""" - - if epoch == 0: - self.unfreeze_and_add_param_group( - pl_module.backbone, optimizer, 0.1, train_bn=self.train_bn, initial_denom_lr=self.initial_denom_lr - ) - chk = ModelCheckpoint(dirpath=tmpdir, save_last=True) model = FinetuningBoringModel() model.validation_step = None - callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False) + callback = TestBackboneFinetuningWarningCallback(unfreeze_backbone_at_epoch=3, verbose=False) with pytest.warns(UserWarning, match="Did you init your optimizer in"): trainer = Trainer( From c2ba01887ccc4df08cb5c596d0cba7da59defd82 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 09:18:11 +0100 Subject: [PATCH 08/26] update on comments --- pytorch_lightning/trainer/callback_hook.py | 19 ++++----- tests/deprecated_api/test_remove_1-5.py | 46 +++++++++++++--------- 2 files changed, 38 insertions(+), 27 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index ac6d446706bc2..684b43aec9002 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -307,15 +307,16 @@ def on_load_checkpoint(self, checkpoint): if callback_states is not None: for callback in self.callbacks: state = deepcopy(callback_states.get(type(callback))) - if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): - rank_zero_deprecation( - "`Callback.on_load_checkpoint` signature has changed in v1.3." - " `Trainer` and `LightningModule` parameter are been added." - " Support for the old signature will be removed in v1.5" - ) - state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled - else: - state = callback.on_load_checkpoint(self, self.lightning_module, state) + if state is not None: + if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): + rank_zero_deprecation( + "`Callback.on_load_checkpoint` signature has changed in v1.3." + " `Trainer` and `LightningModule` parameter have been added." + " Support for the old signature will be removed in v1.5" + ) + state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled + else: + state = callback.on_load_checkpoint(self, self.lightning_module, state) def on_after_backward(self): """ diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 530d1f44df1e2..fd8ddf45b8e1b 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -13,6 +13,7 @@ # limitations under the License. """Test deprecated functionality which will be removed in v1.5.0""" import os +from typing import Any, Dict from unittest import mock import pytest @@ -99,6 +100,16 @@ def __init__(self): self.on_load_checkpoint_called = False +class OldSignatureOnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): + + def on_save_checkpoint(self, *args) -> Dict[str, Any]: + return {"a": 0} + + def on_load_checkpoint(self, callback_state) -> None: + assert callback_state == {"a": 0} + self.on_load_checkpoint_called = True + + class NewSignatureOnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict: @@ -109,13 +120,17 @@ def on_load_checkpoint(self, trainer, pl_module, checkpoint): self.on_load_checkpoint_called = True -def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir): +class ValidSignature2OnLoadCheckpoint(BaseSignatureOnLoadCheckpoint): + + def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> dict: + return {"something": "something"} + + def on_load_checkpoint(self, *args): + assert len(args) == 3 + self.on_load_checkpoint_called = True - class OldSignature(BaseSignatureOnLoadCheckpoint): - def on_load_checkpoint(self, callback_state) -> None: - assert callback_state is None - self.on_load_checkpoint_called = True +def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir): model = BoringModel() trainer_kwargs = { @@ -123,12 +138,12 @@ def on_load_checkpoint(self, callback_state) -> None: "max_epochs": 3, } chk = ModelCheckpoint(save_last=True) - trainer = Trainer(**trainer_kwargs, callbacks=[OldSignature(), chk]) + trainer = Trainer(**trainer_kwargs, callbacks=[OldSignatureOnLoadCheckpoint(), chk]) trainer.fit(model) with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer_kwargs["max_epochs"] = 5 - cb = OldSignature() + cb = OldSignatureOnLoadCheckpoint() trainer = Trainer(**trainer_kwargs, callbacks=cb, resume_from_checkpoint=chk.last_model_path) trainer.fit(model) assert cb.on_load_checkpoint_called @@ -139,12 +154,6 @@ def on_load_checkpoint(self, trainer, *args): assert len(args) == 2 self.on_load_checkpoint_called = True - class ValidSignature2(BaseSignatureOnLoadCheckpoint): - - def on_load_checkpoint(self, *args): - assert len(args) == 3 - self.on_load_checkpoint_called = True - model = BoringModel() trainer_kwargs = { "default_root_dir": tmpdir, @@ -152,19 +161,20 @@ def on_load_checkpoint(self, *args): } chk = ModelCheckpoint(save_last=True) trainer = Trainer( - **trainer_kwargs, callbacks=[NewSignatureOnLoadCheckpoint(), - ValidSignature1(), - ValidSignature2(), chk] + **trainer_kwargs, + callbacks=[NewSignatureOnLoadCheckpoint(), + ValidSignature1(), + ValidSignature2OnLoadCheckpoint(), chk] ) trainer.fit(model) with pytest.deprecated_call(match="old signature will be removed in v1.5"): trainer_kwargs["max_epochs"] = 5 - cb, cb_1, cb_2 = NewSignatureOnLoadCheckpoint(), ValidSignature1(), ValidSignature2() + cb, cb_1, cb_2 = NewSignatureOnLoadCheckpoint(), ValidSignature1(), ValidSignature2OnLoadCheckpoint() trainer = Trainer(**trainer_kwargs, callbacks=[cb, cb_1, cb_2], resume_from_checkpoint=chk.last_model_path) trainer.fit(model) assert cb.on_load_checkpoint_called - assert cb_1.on_load_checkpoint_called + assert not cb_1.on_load_checkpoint_called assert cb_2.on_load_checkpoint_called From 3996cd715e9baccc81e2a59448444f9fca544088 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 09:20:46 +0100 Subject: [PATCH 09/26] update on comments --- pytorch_lightning/callbacks/base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index d4cad4a3b1234..c6650c3b49375 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -278,6 +278,11 @@ def on_load_checkpoint( ) -> None: """Called when loading a model checkpoint, use to reload state. + .. note:: + + The ``on_load_checkpoint`` will be called only if a state is provided. + Therefore, ``on_save_checkpoint`` hook need to be overridden to return a state. + Args: trainer: the current Trainer instance. pl_module: the current 'pl.LightningModule' instance. From 4123987eae4d805615a70cd94de251368e3d88df Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 09:21:34 +0100 Subject: [PATCH 10/26] update --- tests/deprecated_api/test_remove_1-5.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index fd8ddf45b8e1b..e625691f94af2 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -162,9 +162,12 @@ def on_load_checkpoint(self, trainer, *args): chk = ModelCheckpoint(save_last=True) trainer = Trainer( **trainer_kwargs, - callbacks=[NewSignatureOnLoadCheckpoint(), - ValidSignature1(), - ValidSignature2OnLoadCheckpoint(), chk] + callbacks=[ + NewSignatureOnLoadCheckpoint(), + ValidSignature1(), + ValidSignature2OnLoadCheckpoint(), + chk, + ] ) trainer.fit(model) From 5073651a9506ac15053167a171861f046cd6e300 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 09:25:42 +0100 Subject: [PATCH 11/26] update --- tests/deprecated_api/test_remove_1-5.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index e625691f94af2..65aa144beaaca 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -135,14 +135,14 @@ def test_v1_5_0_old_callback_on_load_checkpoint(tmpdir): model = BoringModel() trainer_kwargs = { "default_root_dir": tmpdir, - "max_epochs": 3, + "max_steps": 1, } chk = ModelCheckpoint(save_last=True) trainer = Trainer(**trainer_kwargs, callbacks=[OldSignatureOnLoadCheckpoint(), chk]) trainer.fit(model) with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer_kwargs["max_epochs"] = 5 + trainer_kwargs["max_steps"] = 2 cb = OldSignatureOnLoadCheckpoint() trainer = Trainer(**trainer_kwargs, callbacks=cb, resume_from_checkpoint=chk.last_model_path) trainer.fit(model) @@ -155,10 +155,6 @@ def on_load_checkpoint(self, trainer, *args): self.on_load_checkpoint_called = True model = BoringModel() - trainer_kwargs = { - "default_root_dir": tmpdir, - "max_epochs": 3, - } chk = ModelCheckpoint(save_last=True) trainer = Trainer( **trainer_kwargs, @@ -172,7 +168,6 @@ def on_load_checkpoint(self, trainer, *args): trainer.fit(model) with pytest.deprecated_call(match="old signature will be removed in v1.5"): - trainer_kwargs["max_epochs"] = 5 cb, cb_1, cb_2 = NewSignatureOnLoadCheckpoint(), ValidSignature1(), ValidSignature2OnLoadCheckpoint() trainer = Trainer(**trainer_kwargs, callbacks=[cb, cb_1, cb_2], resume_from_checkpoint=chk.last_model_path) trainer.fit(model) From dbb848b5dbd851a11fa2ebae249baa891e43d868 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 09:27:05 +0100 Subject: [PATCH 12/26] formatting --- pytorch_lightning/callbacks/finetuning.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index d434b77104741..b920969f60c19 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -87,8 +87,12 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): def __init__(self): self._internal_state = {} - def on_save_checkpoint(self, trainer: 'pl.Trainer', pl_module: LightningModule, - checkpoint: Dict[str, Any]) -> Dict[str, Any]: + def on_save_checkpoint( + self, + trainer: 'pl.Trainer', + pl_module: LightningModule, + checkpoint: Dict[str, Any], + ) -> Dict[str, Any]: return self._internal_state def on_load_checkpoint( From 5566a5d530029f38c0cee4bf342058a8276c1933 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 09:30:00 +0100 Subject: [PATCH 13/26] add comments --- pytorch_lightning/callbacks/finetuning.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index b920969f60c19..5d442cfe0a1c6 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -99,7 +99,7 @@ def on_load_checkpoint( self, trainer: 'pl.Trainer', pl_module: LightningModule, callback_state: Dict[str, Any] ) -> None: self._internal_state = callback_state - # restore the param_groups created during training. + # restore the param_groups created during the previous training. _map_name_to_p = {n: p for n, p in pl_module.named_parameters()} for opt_idx, optimizer in enumerate(trainer.optimizers): param_groups = self._param_groups_state_to_param_groups( @@ -269,6 +269,10 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): def _add_to_internal_state( self, pl_module: LightningModule, opt_idx: int, current_param_groups: List[Dict[str, Any]] ) -> None: + """ + This function save the new param_group metadata inside `BaseFinetuning` Callback `internal_state`. + The tensors are being mapped to their names for memory optimization. + """ map_p_to_name = {p: n for n, p in pl_module.named_parameters()} for g in current_param_groups: group_state = {k: v for k, v in g.items() if k != 'params'} @@ -279,10 +283,12 @@ def _store( self, pl_module: LightningModule, opt_idx: int, len_previous_param_groups: List[Dict[str, Any]], current_param_groups: List[Dict[str, Any]] ) -> None: + # save the param_groups on first call. if opt_idx not in self._internal_state: self._internal_state[opt_idx] = [] self._add_to_internal_state(pl_module, opt_idx, current_param_groups) + # save new param_groups possibly created by the users. elif len_previous_param_groups != len(current_param_groups): self._add_to_internal_state(pl_module, opt_idx, current_param_groups[len_previous_param_groups:]) From 7e695cd4bc62f83698aa6522bc22303f5731b826 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 11:34:36 +0100 Subject: [PATCH 14/26] update on comments --- pytorch_lightning/callbacks/base.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index c6650c3b49375..7c997f518a5be 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -278,15 +278,17 @@ def on_load_checkpoint( ) -> None: """Called when loading a model checkpoint, use to reload state. - .. note:: - - The ``on_load_checkpoint`` will be called only if a state is provided. - Therefore, ``on_save_checkpoint`` hook need to be overridden to return a state. - Args: trainer: the current Trainer instance. pl_module: the current 'pl.LightningModule' instance. callback_state: the callback state returned by ``on_save_checkpoint``. + + .. note:: + + The ``on_load_checkpoint`` won't be called with an undefined state. + If your ``on_load_checkpoint`` hook behavior doesn't rely on a state, + you will still need to override ``on_save_checkpoint`` with a ``dummy state`` + for the ``on_load_checkpoint`` to be executed. """ pass From 5cda56400f7be59b45fa79a9d5a04ad0875c8687 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 11:35:03 +0100 Subject: [PATCH 15/26] update --- pytorch_lightning/callbacks/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index 7c997f518a5be..a53577468d196 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -287,8 +287,7 @@ def on_load_checkpoint( The ``on_load_checkpoint`` won't be called with an undefined state. If your ``on_load_checkpoint`` hook behavior doesn't rely on a state, - you will still need to override ``on_save_checkpoint`` with a ``dummy state`` - for the ``on_load_checkpoint`` to be executed. + you will still need to override ``on_save_checkpoint`` to return a ``dummy state``. """ pass From efedcdc730a66b75ac564ea791fc0efbc67754f4 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 29 Apr 2021 13:13:20 +0100 Subject: [PATCH 16/26] Update pytorch_lightning/callbacks/base.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/callbacks/base.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a53577468d196..d2fafbad3d2f2 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -283,8 +283,7 @@ def on_load_checkpoint( pl_module: the current 'pl.LightningModule' instance. callback_state: the callback state returned by ``on_save_checkpoint``. - .. note:: - + Note: The ``on_load_checkpoint`` won't be called with an undefined state. If your ``on_load_checkpoint`` hook behavior doesn't rely on a state, you will still need to override ``on_save_checkpoint`` to return a ``dummy state``. From 79e9286688d8f759a2cd3d8830887015d3152fbe Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 14:40:32 +0100 Subject: [PATCH 17/26] update --- pytorch_lightning/callbacks/base.py | 8 +++--- pytorch_lightning/callbacks/finetuning.py | 34 +++++++++++++---------- 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index d2fafbad3d2f2..48c60338a0bf8 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -264,8 +264,8 @@ def on_save_checkpoint( Called when saving a model checkpoint, use to persist state. Args: - trainer: the current Trainer instance. - pl_module: the current 'pl.LightningModule' instance. + trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. + pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance. checkpoint: the checkpoint dictionary that will be saved. Returns: @@ -279,8 +279,8 @@ def on_load_checkpoint( """Called when loading a model checkpoint, use to reload state. Args: - trainer: the current Trainer instance. - pl_module: the current 'pl.LightningModule' instance. + trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. + pl_module: the current :class:`~pytorch_lightning.core.lightning.LightningModule` instance. callback_state: the callback state returned by ``on_save_checkpoint``. Note: diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 5d442cfe0a1c6..5664f3101b057 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -85,7 +85,7 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): """ def __init__(self): - self._internal_state = {} + self._internal_state: Dict[int, Dict[str, torch.Tensor]] = {} def on_save_checkpoint( self, @@ -100,18 +100,18 @@ def on_load_checkpoint( ) -> None: self._internal_state = callback_state # restore the param_groups created during the previous training. - _map_name_to_p = {n: p for n, p in pl_module.named_parameters()} + name_to_param_mapping = dict(pl_module.named_parameters()) for opt_idx, optimizer in enumerate(trainer.optimizers): - param_groups = self._param_groups_state_to_param_groups( - deepcopy(self._internal_state[opt_idx]), _map_name_to_p + param_groups = self._restore_named_parameters( + deepcopy(self._internal_state[opt_idx]), name_to_param_mapping ) optimizer.param_groups = param_groups - def _param_groups_state_to_param_groups( - self, param_groups_state: Dict[str, Any], map_name_to_p: Dict[str, torch.Tensor] + def _restore_named_parameters( + self, param_groups_state: Dict[str, Any], name_to_param_mapping: Dict[str, torch.Tensor] ) -> Dict[str, Any]: for group in param_groups_state: - group["params"] = [map_name_to_p[name] for name in group["params"]] + group["params"] = [name_to_param_mapping[name] for name in group["params"]] return param_groups_state @staticmethod @@ -273,15 +273,18 @@ def _add_to_internal_state( This function save the new param_group metadata inside `BaseFinetuning` Callback `internal_state`. The tensors are being mapped to their names for memory optimization. """ - map_p_to_name = {p: n for n, p in pl_module.named_parameters()} + param_to_name_mapping = {p: n for n, p in pl_module.named_parameters()} for g in current_param_groups: group_state = {k: v for k, v in g.items() if k != 'params'} - group_state['params'] = [map_p_to_name[p] for p in g['params']] + group_state['params'] = [param_to_name_mapping[p] for p in g['params']] self._internal_state[opt_idx].append(group_state) def _store( - self, pl_module: LightningModule, opt_idx: int, len_previous_param_groups: List[Dict[str, Any]], - current_param_groups: List[Dict[str, Any]] + self, + pl_module: LightningModule, + opt_idx: int, + num_param_groups: int, + current_param_groups: List[Dict[str, Any]], ) -> None: # save the param_groups on first call. if opt_idx not in self._internal_state: @@ -289,16 +292,16 @@ def _store( self._add_to_internal_state(pl_module, opt_idx, current_param_groups) # save new param_groups possibly created by the users. - elif len_previous_param_groups != len(current_param_groups): - self._add_to_internal_state(pl_module, opt_idx, current_param_groups[len_previous_param_groups:]) + elif num_param_groups != len(current_param_groups): + self._add_to_internal_state(pl_module, opt_idx, current_param_groups[num_param_groups:]) def on_train_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): - len_previous_param_groups = len(optimizer.param_groups) + num_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) current_param_groups = optimizer.param_groups - self._store(pl_module, opt_idx, len_previous_param_groups, current_param_groups) + self._store(pl_module, opt_idx, num_param_groups, current_param_groups) def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): """ @@ -367,6 +370,7 @@ def __init__( round: int = 12, ): super().__init__() + self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch self.backbone_initial_lr = backbone_initial_lr self.lambda_func = lambda_func From c916cac1d1be36e0715bbb18a10956aada72cd57 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 29 Apr 2021 14:43:05 +0100 Subject: [PATCH 18/26] update --- pytorch_lightning/callbacks/finetuning.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 5664f3101b057..5bbf49f66eb0a 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -85,18 +85,18 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): """ def __init__(self): - self._internal_state: Dict[int, Dict[str, torch.Tensor]] = {} + self._internal_state: Dict[int, Dict[str, Any]] = {} def on_save_checkpoint( self, trainer: 'pl.Trainer', pl_module: LightningModule, checkpoint: Dict[str, Any], - ) -> Dict[str, Any]: + ) -> Dict[int, Dict[str, Any]]: return self._internal_state def on_load_checkpoint( - self, trainer: 'pl.Trainer', pl_module: LightningModule, callback_state: Dict[str, Any] + self, trainer: 'pl.Trainer', pl_module: LightningModule, callback_state: Dict[int, Dict[str, Any]] ) -> None: self._internal_state = callback_state # restore the param_groups created during the previous training. From bf39544179503ce5279be2cc6c830765959d8585 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 17:18:17 +0200 Subject: [PATCH 19/26] Typing and minor changes --- pytorch_lightning/callbacks/finetuning.py | 20 +++++++++----------- pytorch_lightning/trainer/callback_hook.py | 16 ++++++---------- 2 files changed, 15 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 5bbf49f66eb0a..3379fde5f0fd6 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -27,7 +27,6 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -40,7 +39,6 @@ def multiplicative(epoch): class BaseFinetuning(Callback): r""" - This class implements the base logic for writing your own Finetuning Callback. Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic. @@ -85,18 +83,18 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx): """ def __init__(self): - self._internal_state: Dict[int, Dict[str, Any]] = {} + self._internal_state: Dict[int, List[Dict[str, Any]]] = {} def on_save_checkpoint( self, trainer: 'pl.Trainer', - pl_module: LightningModule, + pl_module: 'pl.LightningModule', checkpoint: Dict[str, Any], - ) -> Dict[int, Dict[str, Any]]: + ) -> Dict[int, List[Dict[str, Any]]]: return self._internal_state def on_load_checkpoint( - self, trainer: 'pl.Trainer', pl_module: LightningModule, callback_state: Dict[int, Dict[str, Any]] + self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', callback_state: Dict[int, List[Dict[str, Any]]] ) -> None: self._internal_state = callback_state # restore the param_groups created during the previous training. @@ -281,7 +279,7 @@ def _add_to_internal_state( def _store( self, - pl_module: LightningModule, + pl_module: 'pl.LightningModule', opt_idx: int, num_param_groups: int, current_param_groups: List[Dict[str, Any]], @@ -303,13 +301,13 @@ def on_train_epoch_start(self, trainer, pl_module): current_param_groups = optimizer.param_groups self._store(pl_module, opt_idx, num_param_groups, current_param_groups) - def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + def finetune_function(self, pl_module: 'pl.LightningModule', epoch: int, optimizer: Optimizer, opt_idx: int): """ Override to add your unfreeze logic """ raise NotImplementedError - def freeze_before_training(self, pl_module: LightningModule): + def freeze_before_training(self, pl_module: 'pl.LightningModule'): """ Override to add your freeze logic """ @@ -391,10 +389,10 @@ def on_fit_start(self, trainer, pl_module): return raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute") - def freeze_before_training(self, pl_module: LightningModule): + def freeze_before_training(self, pl_module: 'pl.LightningModule'): self.freeze(pl_module.backbone) - def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int): + def finetune_function(self, pl_module: 'pl.LightningModule', epoch: int, optimizer: Optimizer, opt_idx: int): """Called when the epoch begins.""" if epoch == self.unfreeze_backbone_at_epoch: current_lr = optimizer.param_groups[0]['lr'] diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 684b43aec9002..dbf147e51303a 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -270,16 +270,12 @@ def on_keyboard_interrupt(self): @staticmethod def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) - if len(parameters) == 2 and parameters[1] != "args": - return True - return False + return len(parameters) == 2 and parameters[1] != "args" @staticmethod def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) - if len(parameters) == 1 and parameters[0] == "callback_state": - return True - return False + return len(parameters) == 1 def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: """Called when saving a model checkpoint.""" @@ -307,16 +303,16 @@ def on_load_checkpoint(self, checkpoint): if callback_states is not None: for callback in self.callbacks: state = deepcopy(callback_states.get(type(callback))) - if state is not None: + if state: if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint): rank_zero_deprecation( "`Callback.on_load_checkpoint` signature has changed in v1.3." - " `Trainer` and `LightningModule` parameter have been added." + " `trainer` and `pl_module` parameters have been added." " Support for the old signature will be removed in v1.5" ) - state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled + callback.on_load_checkpoint(state) # noqa: parameter-unfilled else: - state = callback.on_load_checkpoint(self, self.lightning_module, state) + callback.on_load_checkpoint(self, self.lightning_module, state) def on_after_backward(self): """ From e81d03ff28893dbb9cec75f36d306c215a368790 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 17:31:09 +0200 Subject: [PATCH 20/26] Refactor --- pytorch_lightning/callbacks/finetuning.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 3379fde5f0fd6..495bc574958ac 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -98,19 +98,17 @@ def on_load_checkpoint( ) -> None: self._internal_state = callback_state # restore the param_groups created during the previous training. - name_to_param_mapping = dict(pl_module.named_parameters()) + named_parameters = dict(pl_module.named_parameters()) for opt_idx, optimizer in enumerate(trainer.optimizers): - param_groups = self._restore_named_parameters( - deepcopy(self._internal_state[opt_idx]), name_to_param_mapping - ) + param_groups = self.__apply_mapping_to_param_groups(self._internal_state[opt_idx], named_parameters) optimizer.param_groups = param_groups - def _restore_named_parameters( - self, param_groups_state: Dict[str, Any], name_to_param_mapping: Dict[str, torch.Tensor] - ) -> Dict[str, Any]: - for group in param_groups_state: - group["params"] = [name_to_param_mapping[name] for name in group["params"]] - return param_groups_state + @staticmethod + def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: + param_groups_ = deepcopy(param_groups) + for group in param_groups_: + group["params"] = [mapping[p] for p in group["params"]] + return param_groups_ @staticmethod def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: @@ -265,7 +263,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): self.freeze_before_training(pl_module) def _add_to_internal_state( - self, pl_module: LightningModule, opt_idx: int, current_param_groups: List[Dict[str, Any]] + self, pl_module: 'pl.LightningModule', opt_idx: int, current_param_groups: List[Dict[str, Any]] ) -> None: """ This function save the new param_group metadata inside `BaseFinetuning` Callback `internal_state`. From e5390261aa6a7894e6fc2041821934f9037bda25 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 17:50:46 +0200 Subject: [PATCH 21/26] Fix deprecated test --- pytorch_lightning/trainer/callback_hook.py | 2 +- tests/deprecated_api/test_remove_1-5.py | 9 +++------ 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index dbf147e51303a..844bbfd2fe24c 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -275,7 +275,7 @@ def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool: @staticmethod def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool: parameters = list(signature(fn).parameters) - return len(parameters) == 1 + return len(parameters) == 1 and parameters[0] != "args" def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]: """Called when saving a model checkpoint.""" diff --git a/tests/deprecated_api/test_remove_1-5.py b/tests/deprecated_api/test_remove_1-5.py index 65aa144beaaca..990622eff2adc 100644 --- a/tests/deprecated_api/test_remove_1-5.py +++ b/tests/deprecated_api/test_remove_1-5.py @@ -165,15 +165,12 @@ def on_load_checkpoint(self, trainer, *args): chk, ] ) - trainer.fit(model) + with no_deprecated_call(match="old signature will be removed in v1.5"): + trainer.fit(model) with pytest.deprecated_call(match="old signature will be removed in v1.5"): - cb, cb_1, cb_2 = NewSignatureOnLoadCheckpoint(), ValidSignature1(), ValidSignature2OnLoadCheckpoint() - trainer = Trainer(**trainer_kwargs, callbacks=[cb, cb_1, cb_2], resume_from_checkpoint=chk.last_model_path) + trainer = Trainer(**trainer_kwargs, resume_from_checkpoint=chk.last_model_path) trainer.fit(model) - assert cb.on_load_checkpoint_called - assert not cb_1.on_load_checkpoint_called - assert cb_2.on_load_checkpoint_called def test_v1_5_0_legacy_profiler_argument(): From 6c2a12f699e04c458c4c7bc6f9ca5bd31ec7dd0d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 18:32:33 +0200 Subject: [PATCH 22/26] Broken commit --- pytorch_lightning/callbacks/finetuning.py | 41 +++++++---------------- 1 file changed, 13 insertions(+), 28 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 495bc574958ac..23a67a18b051a 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -103,13 +103,6 @@ def on_load_checkpoint( param_groups = self.__apply_mapping_to_param_groups(self._internal_state[opt_idx], named_parameters) optimizer.param_groups = param_groups - @staticmethod - def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: - param_groups_ = deepcopy(param_groups) - for group in param_groups_: - group["params"] = [mapping[p] for p in group["params"]] - return param_groups_ - @staticmethod def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]: """ @@ -262,18 +255,11 @@ def unfreeze_and_add_param_group( def on_before_accelerator_backend_setup(self, trainer, pl_module): self.freeze_before_training(pl_module) - def _add_to_internal_state( - self, pl_module: 'pl.LightningModule', opt_idx: int, current_param_groups: List[Dict[str, Any]] - ) -> None: - """ - This function save the new param_group metadata inside `BaseFinetuning` Callback `internal_state`. - The tensors are being mapped to their names for memory optimization. - """ - param_to_name_mapping = {p: n for n, p in pl_module.named_parameters()} - for g in current_param_groups: - group_state = {k: v for k, v in g.items() if k != 'params'} - group_state['params'] = [param_to_name_mapping[p] for p in g['params']] - self._internal_state[opt_idx].append(group_state) + @staticmethod + def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: + for group in param_groups: + group["params"] = [mapping[p] for p in group["params"]] + return param_groups def _store( self, @@ -282,21 +268,20 @@ def _store( num_param_groups: int, current_param_groups: List[Dict[str, Any]], ) -> None: - # save the param_groups on first call. - if opt_idx not in self._internal_state: - self._internal_state[opt_idx] = [] - self._add_to_internal_state(pl_module, opt_idx, current_param_groups) - - # save new param_groups possibly created by the users. - elif num_param_groups != len(current_param_groups): - self._add_to_internal_state(pl_module, opt_idx, current_param_groups[num_param_groups:]) + mapping = {p: n for n, p in pl_module.named_parameters()} + self._internal_state.setdefault(opt_idx, self.__apply_mapping_to_param_groups(current_param_groups, mapping)) + if num_param_groups != len(current_param_groups): + # save new param_groups possibly created by the users. + self._internal_state[opt_idx].extend( + self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping) + ) def on_train_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): num_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) - current_param_groups = optimizer.param_groups + current_param_groups = deepcopy(optimizer.param_groups) self._store(pl_module, opt_idx, num_param_groups, current_param_groups) def finetune_function(self, pl_module: 'pl.LightningModule', epoch: int, optimizer: Optimizer, opt_idx: int): From b5b51f34dee39733b2e683422e809eb1f40ca9e3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 18:54:58 +0200 Subject: [PATCH 23/26] Fix broken commit --- pytorch_lightning/callbacks/finetuning.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 23a67a18b051a..4b834e9e9c31e 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -257,9 +257,13 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module): @staticmethod def __apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]: - for group in param_groups: - group["params"] = [mapping[p] for p in group["params"]] - return param_groups + output = [] + for g in param_groups: + # skip params to save memory + group_state = {k: v for k, v in g.items() if k != 'params'} + group_state['params'] = [mapping[p] for p in g['params']] + output.append(group_state) + return output def _store( self, @@ -269,8 +273,9 @@ def _store( current_param_groups: List[Dict[str, Any]], ) -> None: mapping = {p: n for n, p in pl_module.named_parameters()} - self._internal_state.setdefault(opt_idx, self.__apply_mapping_to_param_groups(current_param_groups, mapping)) - if num_param_groups != len(current_param_groups): + if opt_idx not in self._internal_state: + self._internal_state[opt_idx] = self.__apply_mapping_to_param_groups(current_param_groups, mapping) + elif num_param_groups != len(current_param_groups): # save new param_groups possibly created by the users. self._internal_state[opt_idx].extend( self.__apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping) @@ -281,7 +286,7 @@ def on_train_epoch_start(self, trainer, pl_module): for opt_idx, optimizer in trainer.train_loop.prepare_optimizers(): num_param_groups = len(optimizer.param_groups) self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx) - current_param_groups = deepcopy(optimizer.param_groups) + current_param_groups = optimizer.param_groups self._store(pl_module, opt_idx, num_param_groups, current_param_groups) def finetune_function(self, pl_module: 'pl.LightningModule', epoch: int, optimizer: Optimizer, opt_idx: int): From cb17f0a61f19ce10656e87dbd0ec5221fc6692ab Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 19:11:40 +0200 Subject: [PATCH 24/26] flake8 --- pytorch_lightning/callbacks/finetuning.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/callbacks/finetuning.py b/pytorch_lightning/callbacks/finetuning.py index 4b834e9e9c31e..d3f52b4ba9a15 100644 --- a/pytorch_lightning/callbacks/finetuning.py +++ b/pytorch_lightning/callbacks/finetuning.py @@ -17,7 +17,6 @@ Freeze and unfreeze models for finetuning purposes """ import logging -from copy import deepcopy from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union import torch From 9dea46c734dd56aa76fb8f737e7a5405b36335bb Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 29 Apr 2021 19:13:05 +0200 Subject: [PATCH 25/26] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8efefe922c302..9748b2735c02b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -203,6 +203,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Deprecated `@auto_move_data` in favor of `trainer.predict` ([#6993](https://github.com/PyTorchLightning/pytorch-lightning/pull/6993)) +- Deprecated `Callback.on_load_checkpoint(checkpoint)` in favor of `Callback.on_load_checkpoint(trainer, pl_module, checkpoint)` ([#7253](https://github.com/PyTorchLightning/pytorch-lightning/pull/7253)) + + - Deprecated metrics in favor of `torchmetrics` ([#6505](https://github.com/PyTorchLightning/pytorch-lightning/pull/6505), [#6530](https://github.com/PyTorchLightning/pytorch-lightning/pull/6530), [#6540](https://github.com/PyTorchLightning/pytorch-lightning/pull/6540), From aea5cce24e16144aa515538a24b918991e02a441 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 30 Apr 2021 13:43:28 +0100 Subject: [PATCH 26/26] update on comments --- tests/callbacks/test_finetuning_callback.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/callbacks/test_finetuning_callback.py b/tests/callbacks/test_finetuning_callback.py index 971172577e421..c8290f217a289 100644 --- a/tests/callbacks/test_finetuning_callback.py +++ b/tests/callbacks/test_finetuning_callback.py @@ -25,7 +25,7 @@ from tests.helpers import BoringModel, RandomDataset -class TestTestBackboneFinetuningCallbackCallback(BackboneFinetuning): +class TestBackboneFinetuningCallback(BackboneFinetuning): def on_train_epoch_end(self, trainer, pl_module, outputs): epoch = trainer.current_epoch @@ -71,7 +71,7 @@ def train_dataloader(self): return DataLoader(RandomDataset(32, 64), batch_size=2) model = FinetuningBoringModel() - callback = TestTestBackboneFinetuningCallbackCallback(unfreeze_backbone_at_epoch=3, verbose=False) + callback = TestBackboneFinetuningCallback(unfreeze_backbone_at_epoch=3, verbose=False) trainer = Trainer( limit_train_batches=4,