[bugfix] Add reloading support using BaseFinetuning #7253

Merged
merged 29 commits into from Apr 30, 2021
Changes from 17 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -380,6 +380,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed metrics not being properly logged with `precision=16` and `manual_optimization` ([#7228](https://github.com/PyTorchLightning/pytorch-lightning/pull/7228))


- Fixed `BaseFinetuning` to properly reload `optimizer_states` when using `resume_from_checkpoint` ([#6891](https://github.com/PyTorchLightning/pytorch-lightning/pull/6891))


- Fixed `parameters_to_ignore` not properly set to DDPWrapper ([#7239](https://github.com/PyTorchLightning/pytorch-lightning/pull/7239))


12 changes: 11 additions & 1 deletion pytorch_lightning/callbacks/base.py
@@ -273,11 +273,21 @@ def on_save_checkpoint(
"""
pass

def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
def on_load_checkpoint(
self, trainer: 'pl.Trainer', pl_module: 'pl.LightningModule', callback_state: Dict[str, Any]
) -> None:
"""Called when loading a model checkpoint, use to reload state.

Args:
trainer: the current Trainer instance.
pl_module: the current ``pl.LightningModule`` instance.
callback_state: the callback state returned by ``on_save_checkpoint``.

.. note::

The ``on_load_checkpoint`` hook won't be called with an undefined state.
If your ``on_load_checkpoint`` hook behavior doesn't rely on a state,
you will still need to override ``on_save_checkpoint`` to return a ``dummy state``.
"""
pass

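For illustration only (not part of this diff), a minimal sketch of a callback adopting the updated hooks; the `CounterCallback` name and its `epochs_seen` state are invented for the example:

from pytorch_lightning.callbacks.base import Callback


class CounterCallback(Callback):
    """Illustrative callback that persists how many epochs it has seen."""

    def __init__(self):
        self.epochs_seen = 0

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        self.epochs_seen += 1

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # the returned dict is stored in the checkpoint under this callback's type
        return {"epochs_seen": self.epochs_seen}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        # new v1.3 signature: trainer and pl_module are available alongside the saved state
        self.epochs_seen = callback_state["epochs_seen"]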
67 changes: 64 additions & 3 deletions pytorch_lightning/callbacks/finetuning.py
Expand Up @@ -17,13 +17,15 @@
Freeze and unfreeze models for finetuning purposes
"""
import logging
from typing import Callable, Generator, Iterable, List, Optional, Union
from copy import deepcopy
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union

import torch
from torch.nn import Module
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn
@@ -68,7 +70,7 @@ def __init__(self, unfreeze_at_epoch=10)
self._unfreeze_at_epoch = unfreeze_at_epoch

def freeze_before_training(self, pl_module):
# freeze any module you want
# Here, we are freezing ``feature_extractor``
self.freeze(pl_module.feature_extractor)

@@ -82,6 +84,36 @@ def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
)
"""

def __init__(self):
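# maps optimizer index -> list of recorded param_group entries, with tensors replaced by their parameter names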
self._internal_state = {}

def on_save_checkpoint(
self,
trainer: 'pl.Trainer',
pl_module: LightningModule,
checkpoint: Dict[str, Any],
) -> Dict[str, Any]:
return self._internal_state

def on_load_checkpoint(
self, trainer: 'pl.Trainer', pl_module: LightningModule, callback_state: Dict[str, Any]
) -> None:
self._internal_state = callback_state
# restore the param_groups created during the previous training.
_map_name_to_p = {n: p for n, p in pl_module.named_parameters()}
for opt_idx, optimizer in enumerate(trainer.optimizers):
param_groups = self._param_groups_state_to_param_groups(
deepcopy(self._internal_state[opt_idx]), _map_name_to_p
)
optimizer.param_groups = param_groups

def _param_groups_state_to_param_groups(
self, param_groups_state: List[Dict[str, Any]], map_name_to_p: Dict[str, torch.Tensor]
) -> List[Dict[str, Any]]:
for group in param_groups_state:
group["params"] = [map_name_to_p[name] for name in group["params"]]
return param_groups_state

@staticmethod
def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
"""
@@ -234,10 +266,39 @@ def unfreeze_and_add_param_group(
def on_before_accelerator_backend_setup(self, trainer, pl_module):
self.freeze_before_training(pl_module)

def _add_to_internal_state(
self, pl_module: LightningModule, opt_idx: int, current_param_groups: List[Dict[str, Any]]
) -> None:
"""
This function saves the new param_group metadata inside the `BaseFinetuning` callback's `_internal_state`.
The tensors are mapped to their names for memory efficiency.
"""
map_p_to_name = {p: n for n, p in pl_module.named_parameters()}
for g in current_param_groups:
group_state = {k: v for k, v in g.items() if k != 'params'}
group_state['params'] = [map_p_to_name[p] for p in g['params']]
self._internal_state[opt_idx].append(group_state)

def _store(
self, pl_module: LightningModule, opt_idx: int, len_previous_param_groups: int,
current_param_groups: List[Dict[str, Any]]
) -> None:
# save the param_groups on first call.
if opt_idx not in self._internal_state:
self._internal_state[opt_idx] = []
self._add_to_internal_state(pl_module, opt_idx, current_param_groups)

# save new param_groups possibly created by the users.
elif len_previous_param_groups != len(current_param_groups):
self._add_to_internal_state(pl_module, opt_idx, current_param_groups[len_previous_param_groups:])

def on_train_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
for opt_idx, optimizer in trainer.train_loop.prepare_optimizers():
len_previous_param_groups = len(optimizer.param_groups)
self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
current_param_groups = optimizer.param_groups
self._store(pl_module, opt_idx, len_previous_param_groups, current_param_groups)

def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
"""
@@ -305,6 +366,7 @@ def __init__(
verbose: bool = False,
round: int = 12,
):
super().__init__()
self.unfreeze_backbone_at_epoch = unfreeze_backbone_at_epoch
self.backbone_initial_lr = backbone_initial_lr
self.lambda_func = lambda_func
@@ -330,7 +392,6 @@ def freeze_before_training(self, pl_module: LightningModule):

def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
"""Called when the epoch begins."""

if epoch == self.unfreeze_backbone_at_epoch:
current_lr = optimizer.param_groups[0]['lr']
initial_backbone_lr = self.backbone_initial_lr if self.backbone_initial_lr is not None \
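As a side note (not part of this diff), the name-based bookkeeping done by `_add_to_internal_state` and `_param_groups_state_to_param_groups` boils down to the following self-contained sketch; the tiny model and optimizer here are invented for illustration:

from torch import nn
from torch.optim import SGD

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
optimizer = SGD(model[0].parameters(), lr=0.1)

# "save": replace parameter tensors by their names so the stored state is light and serializable
map_p_to_name = {p: n for n, p in model.named_parameters()}
saved_groups = []
for g in optimizer.param_groups:
    group_state = {k: v for k, v in g.items() if k != "params"}
    group_state["params"] = [map_p_to_name[p] for p in g["params"]]
    saved_groups.append(group_state)
assert saved_groups[0]["params"] == ["0.weight", "0.bias"]

# "load": map the stored names back to the live module's tensors and rebuild the param_groups
map_name_to_p = dict(model.named_parameters())
restored_groups = []
for g in saved_groups:
    g = dict(g)
    g["params"] = [map_name_to_p[n] for n in g["params"]]
    restored_groups.append(g)
optimizer.param_groups = restored_groups

Storing names rather than tensors keeps the checkpointed callback state small and independent of parameter object identity across runs.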
26 changes: 20 additions & 6 deletions pytorch_lightning/trainer/callback_hook.py
@@ -268,17 +268,24 @@ def on_keyboard_interrupt(self):
callback.on_keyboard_interrupt(self, self.lightning_module)

@staticmethod
def __is_old_signature(fn: Callable) -> bool:
def __is_old_signature_on_save_checkpoint(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
if len(parameters) == 2 and parameters[1] != "args":
return True
return False

@staticmethod
def __is_old_signature_on_load_checkpoint(fn: Callable) -> bool:
parameters = list(signature(fn).parameters)
if len(parameters) == 1 and parameters[0] == "callback_state":
return True
return False

def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
if self.__is_old_signature(callback.on_save_checkpoint):
if self.__is_old_signature_on_save_checkpoint(callback.on_save_checkpoint):
rank_zero_deprecation(
"`Callback.on_save_checkpoint` signature has changed in v1.3."
" A `checkpoint` parameter has been added."
@@ -299,10 +306,17 @@ def on_load_checkpoint(self, checkpoint):
# https://github.com/pytorch/xla/issues/2773
if callback_states is not None:
for callback in self.callbacks:
state = callback_states.get(type(callback))
if state:
state = deepcopy(state)
callback.on_load_checkpoint(state)
state = deepcopy(callback_states.get(type(callback)))
if state is not None:
if self.__is_old_signature_on_load_checkpoint(callback.on_load_checkpoint):
rank_zero_deprecation(
"`Callback.on_load_checkpoint` signature has changed in v1.3."
" `Trainer` and `LightningModule` parameter have been added."
" Support for the old signature will be removed in v1.5"
)
state = callback.on_load_checkpoint(state) # noqa: parameter-unfilled
else:
state = callback.on_load_checkpoint(self, self.lightning_module, state)

def on_after_backward(self):
"""
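For context (not part of this diff), a minimal standalone sketch of how the one-parameter check above separates the two hook signatures; both callback classes below are hypothetical:

from inspect import signature


def is_old_on_load_checkpoint(fn) -> bool:
    # the pre-v1.3 hook accepted only `callback_state`
    parameters = list(signature(fn).parameters)
    return len(parameters) == 1 and parameters[0] == "callback_state"


class OldStyleCallback:

    def on_load_checkpoint(self, callback_state):
        ...


class NewStyleCallback:

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        ...


assert is_old_on_load_checkpoint(OldStyleCallback().on_load_checkpoint)
assert not is_old_on_load_checkpoint(NewStyleCallback().on_load_checkpoint)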
115 changes: 87 additions & 28 deletions tests/callbacks/test_finetuning_callback.py
@@ -16,15 +16,29 @@
import pytest
import torch
from torch import nn
from torch.optim import SGD
from torch.optim import Optimizer, SGD
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint
from pytorch_lightning.callbacks.base import Callback
from tests.helpers import BoringModel, RandomDataset


class TestBackboneFinetuningCallback(BackboneFinetuning):

def on_train_epoch_end(self, trainer, pl_module, outputs):
epoch = trainer.current_epoch
if self.unfreeze_backbone_at_epoch <= epoch:
optimizer = trainer.optimizers[0]
current_lr = optimizer.param_groups[0]['lr']
backbone_lr = self.previous_backbone_lr
if epoch < 6:
assert backbone_lr <= current_lr
else:
assert backbone_lr == current_lr


def test_finetuning_callback(tmpdir):
"""Test finetuning callbacks works as expected"""

@@ -56,24 +70,11 @@ def configure_optimizers(self):
def train_dataloader(self):
return DataLoader(RandomDataset(32, 64), batch_size=2)

class TestCallback(BackboneFinetuning):

def on_train_epoch_end(self, trainer, pl_module, outputs):
epoch = trainer.current_epoch
if self.unfreeze_backbone_at_epoch <= epoch:
optimizer = trainer.optimizers[0]
current_lr = optimizer.param_groups[0]['lr']
backbone_lr = self.previous_backbone_lr
if epoch < 6:
assert backbone_lr <= current_lr
else:
assert backbone_lr == current_lr

model = FinetuningBoringModel()
callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False)
callback = TestBackboneFinetuningCallback(unfreeze_backbone_at_epoch=3, verbose=False)

trainer = Trainer(
limit_train_batches=1,
limit_train_batches=4,
default_root_dir=tmpdir,
callbacks=[callback],
max_epochs=8,
@@ -83,6 +84,17 @@ def on_train_epoch_end(self, trainer, pl_module, outputs):
assert model.backbone.has_been_used


class TestBackboneFinetuningWarningCallback(BackboneFinetuning):

def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
"""Called when the epoch begins."""

if epoch == 0:
self.unfreeze_and_add_param_group(
pl_module.backbone, optimizer, 0.1, train_bn=self.train_bn, initial_denom_lr=self.initial_denom_lr
)


def test_finetuning_callback_warning(tmpdir):
"""Test finetuning callbacks works as expected"""

@@ -113,30 +125,24 @@ def configure_optimizers(self):
optimizer = torch.optim.SGD(self.parameters(), lr=0.1)
return optimizer

class TestCallback(BackboneFinetuning):

def finetune_function(self, pl_module, epoch: int, optimizer, opt_idx: int):
"""Called when the epoch begins."""

if epoch == 0:
self.unfreeze_and_add_param_group(
pl_module.backbone, optimizer, 0.1, train_bn=self.train_bn, initial_denom_lr=self.initial_denom_lr
)
chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)

model = FinetuningBoringModel()
model.validation_step = None
callback = TestCallback(unfreeze_backbone_at_epoch=3, verbose=False)
callback = TestBackboneFinetuningWarningCallback(unfreeze_backbone_at_epoch=3, verbose=False)

with pytest.warns(UserWarning, match="Did you init your optimizer in"):
trainer = Trainer(
limit_train_batches=1,
default_root_dir=tmpdir,
callbacks=[callback],
callbacks=[callback, chk],
max_epochs=2,
)
trainer.fit(model)

assert model.backbone.has_been_used
trainer = Trainer(max_epochs=3, resume_from_checkpoint=chk.last_model_path)
trainer.fit(model)


def test_freeze_unfreeze_function(tmpdir):
@@ -220,6 +226,59 @@ def __init__(self):
assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)


class OnEpochLayerFinetuning(BaseFinetuning):

def freeze_before_training(self, pl_module: LightningModule):
self.freeze(pl_module.layer)

def finetune_function(self, pl_module: LightningModule, epoch: int, optimizer: Optimizer, opt_idx: int):
self.unfreeze_and_add_param_group(pl_module.layer[epoch + 1], optimizer)


def test_base_finetuning_internal_state(tmpdir):
"""Test the param_groups updates are properly saved within the internal state of the BaseFinetuning Callbacks"""

seed_everything(42)

class FreezeModel(BoringModel):

def __init__(self):
super().__init__()
self.layer = nn.Sequential(
nn.Linear(32, 32, bias=False),
nn.Linear(32, 32, bias=True),
nn.Linear(32, 32, bias=False),
nn.Linear(32, 32, bias=True),
nn.Linear(32, 32, bias=False),
nn.Linear(32, 2, bias=True),
)

def forward(self, x):
return self.layer(x)

def configure_optimizers(self):
return torch.optim.SGD(self.layer[0].parameters(), lr=0.1)

cb = OnEpochLayerFinetuning()
chk = ModelCheckpoint(dirpath=tmpdir, save_last=True)
model = FreezeModel()
trainer = Trainer(default_root_dir=tmpdir, max_epochs=5, limit_train_batches=1, callbacks=[cb, chk])
trainer.fit(model)
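# one param group from configure_optimizers plus one new group per epoch (epochs 0-4): 6 groups recorded in total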
assert len(cb._internal_state[0]) == 6
assert cb._internal_state[0][0]["params"] == ['layer.0.weight']
assert cb._internal_state[0][1]["params"] == ['layer.1.weight', 'layer.1.bias']
assert cb._internal_state[0][2]["params"] == ['layer.2.weight']
assert cb._internal_state[0][3]["params"] == ['layer.3.weight', 'layer.3.bias']
assert cb._internal_state[0][4]["params"] == ['layer.4.weight']
assert cb._internal_state[0][5]["params"] == ['layer.5.weight', 'layer.5.bias']

model = FreezeModel()
cb = OnEpochLayerFinetuning()
trainer = Trainer(max_epochs=10, resume_from_checkpoint=chk.last_model_path, callbacks=[cb])
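# resuming from the epoch-4 checkpoint restores the 6 recorded param_groups; at epoch 5 finetune_function requests layer[6], which does not exist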
with pytest.raises(IndexError, match="index 6 is out of range"):
trainer.fit(model)


def test_on_before_accelerator_backend_setup(tmpdir):
"""
`on_before_accelerator_backend_setup` hook is used by finetuning callbacks to freeze the model before