Change backbone finetuning strategy to allow for DDP
See Lightning-AI/pytorch-lightning#20340 for the original issue.
The implemented workaround no longer freezes/unfreezes layers via
requires_grad; instead, we only modulate the learning rate of the
backbone layers.
ksikka committed Oct 15, 2024
1 parent ad6ae86 commit de97315
Showing 6 changed files with 109 additions and 86 deletions.
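To make the new schedule concrete, here is a standalone sketch of the backbone LR ramp implemented by the callback added below. The `unfreeze_epoch=2` and `upsampling_lr=1e-3` values are illustrative (the real values come from the config and the optimizer); the ratios are the defaults introduced in this commit.

# Sketch of the ramp produced by UnfreezeBackbone._get_next_backbone_lr, iterated
# over epochs. Illustrative values only; not part of the committed code.
def backbone_lr_schedule(num_epochs, unfreeze_epoch, upsampling_lr=1e-3,
                         initial_ratio=0.1, epoch_ratio=1.5):
    lr = 0.0
    schedule = []
    for epoch in range(num_epochs):
        if epoch < unfreeze_epoch:
            lr = 0.0  # backbone effectively frozen, but still trainable for DDP
        elif epoch == unfreeze_epoch:
            lr = initial_ratio * upsampling_lr  # warm start
        else:
            lr = min(lr * epoch_ratio, upsampling_lr)  # geometric ramp, capped
        schedule.append(lr)
    return schedule

print(backbone_lr_schedule(num_epochs=10, unfreeze_epoch=2))
# -> 0.0, 0.0, then 1e-4, 1.5e-4, 2.25e-4, ... until capped at upsampling_lr (1e-3)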
58 changes: 58 additions & 0 deletions lightning_pose/callbacks.py
@@ -42,3 +42,61 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
# Dan: removed buffer; seems to complicate checkpoint loading
# pl_module.register_buffer(self.attr_name, torch.tensor(value))
setattr(pl_module, self.attr_name, torch.tensor(value))


class UnfreezeBackbone(Callback):
"""Callback that ramps up the backbone learning rate from 0 to `upsampling_lr` on
`unfreeze_epoch`.
Starts LR at `initial_ratio * upsampling_lr`. Grows lr by a factor of `epoch_ratio` per
epoch. Once LR reaches `upsampling_lr`, keeps it in sync with `upsampling_lr`.
Use instead of pl.callbacks.BackboneFinetuning in order to use multi-GPU (DDP). See
lightning-ai/pytorch-lightning#20340 for context.
"""

def __init__(
self,
unfreeze_epoch,
initial_ratio=0.1,
epoch_ratio=1.5,
):
self.unfreeze_epoch = unfreeze_epoch
self.initial_ratio = initial_ratio
self.epoch_ratio = epoch_ratio
self._align_backbone_lr_with_upsampling_lr = False

def on_train_epoch_start(self, trainer, pl_module):
optimizer = pl_module.optimizers()
# This callback is only applicable to heatmap models but we
# might encounter a RegressionModel.
if not hasattr(pl_module, "upsampling_layers"):
return

assert optimizer.param_groups[0]["name"] == "backbone"
assert optimizer.param_groups[1]["name"].startswith("upsampling")

backbone_lr = optimizer.param_groups[0]["lr"]
upsampling_lr = optimizer.param_groups[1]["lr"]

optimizer.param_groups[0]["lr"] = self._get_next_backbone_lr(
pl_module.current_epoch, backbone_lr, upsampling_lr
)

def _get_next_backbone_lr(self, current_epoch, backbone_lr, upsampling_lr):
if self._align_backbone_lr_with_upsampling_lr:
return upsampling_lr

if current_epoch < self.unfreeze_epoch:
assert backbone_lr == 0.0, backbone_lr
return 0.0

if current_epoch == self.unfreeze_epoch:
assert backbone_lr == 0.0, backbone_lr
return self.initial_ratio * upsampling_lr

if current_epoch > self.unfreeze_epoch:
next_lr = min(backbone_lr * self.epoch_ratio, upsampling_lr)
if next_lr == upsampling_lr:
self._align_backbone_lr_with_upsampling_lr = True
return next_lr
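A minimal usage sketch of the new callback, assuming a toy LightningModule that exposes the `backbone`/`upsampling_layers` attributes and the named param groups the callback asserts on. The model class, LR values, and trainer settings here are hypothetical, not part of the commit.

import pytorch_lightning as pl
import torch
from torch.optim import Adam

from lightning_pose.callbacks import UnfreezeBackbone


class ToyHeatmapModel(pl.LightningModule):
    """Minimal stand-in exposing what UnfreezeBackbone looks for."""

    def __init__(self):
        super().__init__()
        self.backbone = torch.nn.Linear(8, 8)
        self.upsampling_layers = torch.nn.Linear(8, 2)

    def configure_optimizers(self):
        # Param-group order and names matter: the callback asserts that group 0
        # is "backbone" and group 1 starts with "upsampling".
        return Adam(
            [
                {"params": self.backbone.parameters(), "lr": 0.0, "name": "backbone"},
                {"params": self.upsampling_layers.parameters(), "name": "upsampling"},
            ],
            lr=1e-3,
        )


trainer = pl.Trainer(
    callbacks=[UnfreezeBackbone(unfreeze_epoch=2)],
    strategy="ddp",  # the point of the change: this schedule is DDP-safe
    max_epochs=10,
)
# trainer.fit(ToyHeatmapModel(), train_dataloaders=...)  # data and training_step omitted in this sketch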
69 changes: 12 additions & 57 deletions lightning_pose/models/base.py
@@ -370,11 +370,7 @@ def forward(
"""
return self.get_representations(images)

def get_parameters(self):
return filter(lambda p: p.requires_grad, self.parameters())

def get_scheduler(self, optimizer):

# define a scheduler that reduces the base learning rate
if self.lr_scheduler == "multisteplr" or self.lr_scheduler == "multistep_lr":

@@ -393,6 +389,17 @@ def get_scheduler(self, optimizer):

return scheduler

def get_parameters(self):
if getattr(self, "upsampling_layers", None) is not None:
params = [
{"params": self.backbone.parameters(), "lr": 0, "name": "backbone"},
{"params": self.upsampling_layers.parameters(), "name": "upsampling"},
]
else:
params = filter(lambda p: p.requires_grad, self.parameters())

return params

def configure_optimizers(self) -> dict:
"""Select optimizer, lr scheduler, and metric for monitoring."""

Expand All @@ -401,6 +408,7 @@ def configure_optimizers(self) -> dict:

# init standard adam optimizer
optimizer = Adam(params, lr=1e-3)
self._optimizer = optimizer

# get learning rate scheduler
scheduler = self.get_scheduler(optimizer)
@@ -499,23 +507,6 @@ def test_step(
"""Base test step, a wrapper around the `evaluate_labeled` method."""
self.evaluate_labeled(batch_dict, "test")

def get_parameters(self):

if getattr(self, "upsampling_layers", None) is not None:
# single optimizer with single learning rate
params = [
# don't uncomment line below;
# the BackboneFinetuning callback should add backbone to the params
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers.parameters()},
]
else:
# standard adam optimizer
params = filter(lambda p: p.requires_grad, self.parameters())

return params


@typechecked
class SemiSupervisedTrackerMixin(object):
@@ -589,39 +580,3 @@ def training_step(
self.log("total_loss", total_loss, prog_bar=True)

return {"loss": total_loss}

def get_parameters(self):

if getattr(self, "upsampling_layers", None) is not None:
# if we're here this is a heatmap model
params = [
# don't uncomment line below;
# the BackboneFinetuning callback should add backbone to the params
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers.parameters()},
]

else:
# standard adam optimizer for regression model
params = filter(lambda p: p.requires_grad, self.parameters())

return params

# # single optimizer with different learning rates
# def configure_optimizers(self):
# params_net = [
# # {"params": self.backbone.parameters()},
# # don't uncomment above line; the BackboneFinetuning callback should add
# # backbone to the params.
# {
# "params": self.upsampling_layers.parameters()
# }, # important that this is the 0th element, for BackboneFineTuning
# ]
# optimizer = Adam(params_net, lr=1e-3)
# scheduler = MultiStepLR(optimizer, milestones=[100, 200, 300], gamma=0.5)
#
# optimizers = [optimizer]
# lr_schedulers = [scheduler]
#
# return optimizers, lr_schedulers
23 changes: 6 additions & 17 deletions lightning_pose/models/heatmap_tracker_mhcrnn.py
@@ -251,12 +251,12 @@ def predict_step(

def get_parameters(self):
params = [
# don't uncomment line below
# the BackboneFinetuning callback should add backbone to the params.
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers_rnn.parameters()},
{"params": self.upsampling_layers_sf.parameters()},
{"params": self.backbone.parameters(), "name": "backbone", "lr": 0.0},
{
"params": self.upsampling_layers_rnn.parameters(),
"name": "upsampling_rnn",
},
{"params": self.upsampling_layers_sf.parameters(), "name": "upsampling_sf"},
]
return params

@@ -361,17 +361,6 @@ def get_loss_inputs_unlabeled(
"confidences": torch.cat([confidence_crnn, confidence_sf], dim=0),
}

def get_parameters(self):
params = [
# don't uncomment line below
# the BackboneFinetuning callback should add backbone to the params.
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers_rnn.parameters()},
{"params": self.upsampling_layers_sf.parameters()},
]
return params


class UpsamplingCRNN(torch.nn.Module):
"""Bidirectional Convolutional RNN network that handles heatmaps of context frames.
21 changes: 10 additions & 11 deletions lightning_pose/utils/scripts.py
@@ -13,7 +13,7 @@
from omegaconf import DictConfig, OmegaConf
from typeguard import typechecked

from lightning_pose.callbacks import AnnealWeight
from lightning_pose.callbacks import AnnealWeight, UnfreezeBackbone
from lightning_pose.data.augmentations import imgaug_transform
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datasets import (
@@ -286,8 +286,13 @@ def get_model(
) -> pl.LightningModule:
"""Create model: regression or heatmap based, supervised or semi-supervised."""

lr_scheduler = cfg.training["lr_scheduler"]
lr_scheduler_params = cfg.training["lr_scheduler_params"][lr_scheduler]
lr_scheduler = cfg.training.lr_scheduler

lr_scheduler_params = OmegaConf.to_object(
cfg.training.lr_scheduler_params[lr_scheduler]
)
lr_scheduler_params["unfreeze_backbone_at_epoch"] = cfg.training.unfreezing_epoch

semi_supervised = check_if_semi_supervised(cfg.model.losses_to_use)
image_h = cfg.data.image_resize_dims.height
image_w = cfg.data.image_resize_dims.width
@@ -452,14 +457,8 @@ def get_callbacks(
callbacks.append(ckpt_callback)

if backbone_unfreeze:
transfer_unfreeze_callback = pl.callbacks.BackboneFinetuning(
unfreeze_backbone_at_epoch=cfg.training.unfreezing_epoch,
lambda_func=lambda epoch: 1.5,
backbone_initial_ratio_lr=0.1,
should_align=True,
train_bn=True,
)
callbacks.append(transfer_unfreeze_callback)
unfreeze_backbone_callback = UnfreezeBackbone(cfg.training.unfreezing_epoch)
callbacks.append(unfreeze_backbone_callback)

# we just need this callback for unsupervised models
if (cfg.model.losses_to_use != []) and (cfg.model.losses_to_use is not None):
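For orientation, a sketch of how the get_model change above reads and augments the scheduler params. The field names (lr_scheduler, lr_scheduler_params, unfreezing_epoch) come from the diff; the values below are made up, and the real ones live in the project's training config.

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "training": {
            "lr_scheduler": "multisteplr",
            "lr_scheduler_params": {
                "multisteplr": {"milestones": [150, 200, 250], "gamma": 0.5},
            },
            "unfreezing_epoch": 20,
        }
    }
)

lr_scheduler = cfg.training.lr_scheduler
# to_object yields a plain dict, so the unfreezing epoch can be injected into the
# scheduler params before they are handed to the model.
lr_scheduler_params = OmegaConf.to_object(cfg.training.lr_scheduler_params[lr_scheduler])
lr_scheduler_params["unfreeze_backbone_at_epoch"] = cfg.training.unfreezing_epoch
print(lr_scheduler_params)
# {'milestones': [150, 200, 250], 'gamma': 0.5, 'unfreeze_backbone_at_epoch': 20}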
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -528,7 +528,7 @@ def video_dataloader(cfg, base_dataset, video_list) -> LitDaliWrapper:
def trainer(cfg) -> pl.Trainer:
"""Create a basic pytorch lightning trainer for testing models."""

cfg.training.unfreezing_epoch = 10 # force no unfreezing to keep memory reqs of tests down
cfg.training.unfreezing_epoch = 1 # exercise unfreezing
callbacks = get_callbacks(cfg, early_stopping=False, lr_monitor=False, backbone_unfreeze=True)

trainer = pl.Trainer(
22 changes: 22 additions & 0 deletions tests/test_callbacks.py
@@ -0,0 +1,22 @@
from lightning_pose.callbacks import UnfreezeBackbone

def test_unfreeze_backbone():
unfreeze_backbone = UnfreezeBackbone(2)

# Test unfreezing at epoch 2.
assert unfreeze_backbone._get_next_backbone_lr(0, 0.0, 1e-3) == 0.0
assert unfreeze_backbone._get_next_backbone_lr(1, 0.0, 1e-3) == 0.0
assert unfreeze_backbone._get_next_backbone_lr(2, 0.0, 1e-3) == 1e-4

# Test warming up.
assert unfreeze_backbone._get_next_backbone_lr(3, 1e-4, 1e-3) == 1e-4 * 1.5
assert unfreeze_backbone._get_next_backbone_lr(4, 1e-4 * 1.5, 1e-3) == 1e-4 * 1.5 * 1.5

# Test aligning with upsampling_lr.
assert not unfreeze_backbone._align_backbone_lr_with_upsampling_lr
assert unfreeze_backbone._get_next_backbone_lr(5, 0.9e-3, 1e-3) == 1e-3
assert unfreeze_backbone._align_backbone_lr_with_upsampling_lr
# From now on, for any backbone_lr it should always return upsampling_lr.
assert unfreeze_backbone._get_next_backbone_lr(6, 1, 1e-2) == 1e-2
assert unfreeze_backbone._get_next_backbone_lr(6, 1, 1e-3) == 1e-3
assert unfreeze_backbone._get_next_backbone_lr(6, 1, 1e-4) == 1e-4
