Change backbone finetuning strategy to allow for DDP #205

Merged 1 commit on Oct 16, 2024
70 changes: 70 additions & 0 deletions lightning_pose/callbacks.py
@@ -42,3 +42,73 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
# Dan: removed buffer; seems to complicate checkpoint loading
# pl_module.register_buffer(self.attr_name, torch.tensor(value))
setattr(pl_module, self.attr_name, torch.tensor(value))


class UnfreezeBackbone(Callback):
"""Callback that ramps up the backbone learning rate from 0 to `upsampling_lr` on
`unfreeze_epoch`.

Starts LR at `initial_ratio * upsampling_lr`. Grows lr by a factor of `epoch_ratio` per
epoch. Once LR reaches `upsampling_lr`, keeps it in sync with `upsampling_lr`.

Use this instead of pl.callbacks.BackboneFinetuning in order to support multi-GPU
training (DDP). See lightning-ai/pytorch-lightning#20340 for context.
"""

def __init__(
self,
unfreeze_epoch,
initial_ratio=0.1,
epoch_ratio=1.5,
):
self.unfreeze_epoch = unfreeze_epoch
self.initial_ratio = initial_ratio
self.epoch_ratio = epoch_ratio
self._warmed_up = False

def on_train_epoch_start(self, trainer, pl_module):
# This callback is only applicable to heatmap models but we
# might encounter a RegressionModel.
if not hasattr(pl_module, "upsampling_layers"):
return

# Once backbone_lr warms up to upsampling_lr, this callback does nothing.
# Control of backbone lr is then the sole job of the main lr scheduler.
if self._warmed_up:
return

optimizer = pl_module.optimizers()
# Check our assumptions about param group indices
assert optimizer.param_groups[0]["name"] == "backbone"
assert optimizer.param_groups[1]["name"].startswith("upsampling")

upsampling_lr = optimizer.param_groups[1]["lr"]

optimizer.param_groups[0]["lr"] = self._get_backbone_lr(
pl_module.current_epoch, upsampling_lr
)

def _get_backbone_lr(self, current_epoch, upsampling_lr):
assert not self._warmed_up

# Before unfreeze, learning_rate is 0.
if current_epoch < self.unfreeze_epoch:
return 0.0

# On unfreeze, initialize learning rate.
# Remember this initial value for warm up.
if current_epoch == self.unfreeze_epoch:
self._initial_lr = self.initial_ratio * upsampling_lr
return self._initial_lr

# Warm up: compute initial_lr * epoch_ratio ** epochs_since_thaw.
# Use the stored initial lr rather than recomputing it, since
# upsampling_lr is subject to change via the scheduler.
if current_epoch > self.unfreeze_epoch:
epochs_since_thaw = current_epoch - self.unfreeze_epoch
next_lr = min(
self._initial_lr * self.epoch_ratio**epochs_since_thaw, upsampling_lr
)
if next_lr == upsampling_lr:
self._warmed_up = True
return next_lr
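
As an aside, a minimal standalone sketch of the warm-up schedule this callback produces; the hyperparameter values below are illustrative, not defaults taken from the PR, and upsampling_lr is held fixed for simplicity (the real callback stores the initial lr at thaw and later defers to the scheduler):

# Reproduces the UnfreezeBackbone ramp-up outside of Lightning, assuming
# upsampling_lr stays fixed at 1e-3 throughout.
unfreeze_epoch, initial_ratio, epoch_ratio, upsampling_lr = 2, 0.1, 1.5, 1e-3

for epoch in range(10):
    if epoch < unfreeze_epoch:
        backbone_lr = 0.0  # backbone still frozen
    else:
        epochs_since_thaw = epoch - unfreeze_epoch
        backbone_lr = min(
            initial_ratio * upsampling_lr * epoch_ratio ** epochs_since_thaw,
            upsampling_lr,
        )
    print(f"epoch {epoch}: backbone_lr = {backbone_lr:.2e}")
# -> 0.0 for epochs 0-1, then 1.00e-04, 1.50e-04, 2.25e-04, ... capped at 1.00e-03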
68 changes: 11 additions & 57 deletions lightning_pose/models/base.py
@@ -370,11 +370,7 @@ def forward(
"""
return self.get_representations(images)

def get_parameters(self):
return filter(lambda p: p.requires_grad, self.parameters())

def get_scheduler(self, optimizer):

# define a scheduler that reduces the base learning rate
if self.lr_scheduler == "multisteplr" or self.lr_scheduler == "multistep_lr":

@@ -393,6 +389,17 @@ def get_scheduler(self, optimizer):

return scheduler

def get_parameters(self):
if getattr(self, "upsampling_layers", None) is not None:
params = [
{"params": self.backbone.parameters(), "lr": 0, "name": "backbone"},
{"params": self.upsampling_layers.parameters(), "name": "upsampling"},
]
else:
params = filter(lambda p: p.requires_grad, self.parameters())

return params

def configure_optimizers(self) -> dict:
"""Select optimizer, lr scheduler, and metric for monitoring."""

@@ -499,23 +506,6 @@ def test_step(
"""Base test step, a wrapper around the `evaluate_labeled` method."""
self.evaluate_labeled(batch_dict, "test")

def get_parameters(self):

if getattr(self, "upsampling_layers", None) is not None:
# single optimizer with single learning rate
params = [
# don't uncomment line below;
# the BackboneFinetuning callback should add backbone to the params
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers.parameters()},
]
else:
# standard adam optimizer
params = filter(lambda p: p.requires_grad, self.parameters())

return params


@typechecked
class SemiSupervisedTrackerMixin(object):
@@ -589,39 +579,3 @@ def training_step(
self.log("total_loss", total_loss, prog_bar=True)

return {"loss": total_loss}

def get_parameters(self):

if getattr(self, "upsampling_layers", None) is not None:
# if we're here this is a heatmap model
params = [
# don't uncomment line below;
# the BackboneFinetuning callback should add backbone to the params
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers.parameters()},
]

else:
# standard adam optimizer for regression model
params = filter(lambda p: p.requires_grad, self.parameters())

return params

# # single optimizer with different learning rates
# def configure_optimizers(self):
# params_net = [
# # {"params": self.backbone.parameters()},
# # don't uncomment above line; the BackboneFinetuning callback should add
# # backbone to the params.
# {
# "params": self.upsampling_layers.parameters()
# }, # important that this is the 0th element, for BackboneFineTuning
# ]
# optimizer = Adam(params_net, lr=1e-3)
# scheduler = MultiStepLR(optimizer, milestones=[100, 200, 300], gamma=0.5)
#
# optimizers = [optimizer]
# lr_schedulers = [scheduler]
#
# return optimizers, lr_schedulers
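
To make the role of the group names concrete, here is a minimal sketch, with plain torch modules standing in for the real backbone and upsampling head, of how the dicts returned by get_parameters surface on optimizer.param_groups, which is what UnfreezeBackbone inspects:

import torch

# Hypothetical stand-ins for the model's backbone and upsampling_layers.
backbone = torch.nn.Linear(8, 8)
upsampling_layers = torch.nn.Linear(8, 4)

params = [
    {"params": backbone.parameters(), "lr": 0, "name": "backbone"},
    {"params": upsampling_layers.parameters(), "name": "upsampling"},
]
optimizer = torch.optim.Adam(params, lr=1e-3)

# Extra keys such as "name" are preserved per param group, so the callback can
# check that group 0 is the backbone and overwrite only its learning rate.
assert optimizer.param_groups[0]["name"] == "backbone"
assert optimizer.param_groups[0]["lr"] == 0
assert optimizer.param_groups[1]["name"] == "upsampling"
assert optimizer.param_groups[1]["lr"] == 1e-3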
23 changes: 6 additions & 17 deletions lightning_pose/models/heatmap_tracker_mhcrnn.py
@@ -251,12 +251,12 @@ def predict_step(

def get_parameters(self):
params = [
# don't uncomment line below
# the BackboneFinetuning callback should add backbone to the params.
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers_rnn.parameters()},
{"params": self.upsampling_layers_sf.parameters()},
{"params": self.backbone.parameters(), "name": "backbone", "lr": 0.0},
{
"params": self.upsampling_layers_rnn.parameters(),
"name": "upsampling_rnn",
},
{"params": self.upsampling_layers_sf.parameters(), "name": "upsampling_sf"},
]
return params

@@ -361,17 +361,6 @@ def get_loss_inputs_unlabeled(
"confidences": torch.cat([confidence_crnn, confidence_sf], dim=0),
}

def get_parameters(self):
params = [
# don't uncomment line below
# the BackboneFinetuning callback should add backbone to the params.
# {"params": self.backbone.parameters()},
# important this is the 0th element, for BackboneFinetuning callback
{"params": self.upsampling_layers_rnn.parameters()},
{"params": self.upsampling_layers_sf.parameters()},
]
return params


class UpsamplingCRNN(torch.nn.Module):
"""Bidirectional Convolutional RNN network that handles heatmaps of context frames.
26 changes: 13 additions & 13 deletions lightning_pose/utils/scripts.py
@@ -13,7 +13,7 @@
from omegaconf import DictConfig, OmegaConf
from typeguard import typechecked

from lightning_pose.callbacks import AnnealWeight
from lightning_pose.callbacks import AnnealWeight, UnfreezeBackbone
from lightning_pose.data.augmentations import imgaug_transform
from lightning_pose.data.datamodules import BaseDataModule, UnlabeledDataModule
from lightning_pose.data.datasets import (
@@ -286,8 +286,13 @@ def get_model(
) -> pl.LightningModule:
"""Create model: regression or heatmap based, supervised or semi-supervised."""

lr_scheduler = cfg.training["lr_scheduler"]
lr_scheduler_params = cfg.training["lr_scheduler_params"][lr_scheduler]
lr_scheduler = cfg.training.lr_scheduler

lr_scheduler_params = OmegaConf.to_object(
cfg.training.lr_scheduler_params[lr_scheduler]
)
lr_scheduler_params["unfreeze_backbone_at_epoch"] = cfg.training.unfreezing_epoch

semi_supervised = check_if_semi_supervised(cfg.model.losses_to_use)
image_h = cfg.data.image_resize_dims.height
image_w = cfg.data.image_resize_dims.width
@@ -430,7 +435,12 @@ def get_callbacks(
)
callbacks.append(early_stopping)

if backbone_unfreeze:
unfreeze_backbone_callback = UnfreezeBackbone(cfg.training.unfreezing_epoch)
callbacks.append(unfreeze_backbone_callback)

if lr_monitor:
# this callback should be added after UnfreezeBackbone so that it logs the backbone learning rate set by UnfreezeBackbone
lr_monitor = pl.callbacks.LearningRateMonitor(logging_interval="epoch")
callbacks.append(lr_monitor)

@@ -451,16 +461,6 @@
)
callbacks.append(ckpt_callback)

if backbone_unfreeze:
transfer_unfreeze_callback = pl.callbacks.BackboneFinetuning(
unfreeze_backbone_at_epoch=cfg.training.unfreezing_epoch,
lambda_func=lambda epoch: 1.5,
backbone_initial_ratio_lr=0.1,
should_align=True,
train_bn=True,
)
callbacks.append(transfer_unfreeze_callback)

# we just need this callback for unsupervised models
if (cfg.model.losses_to_use != []) and (cfg.model.losses_to_use is not None):
anneal_weight_callback = AnnealWeight(**cfg.callbacks.anneal_weight)
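
A condensed sketch of the resulting callback wiring (the unfreezing epoch below is a placeholder for cfg.training.unfreezing_epoch), illustrating why UnfreezeBackbone is appended before LearningRateMonitor: callbacks run in list order within a hook, so the monitor should log the backbone lr the callback just set.

import pytorch_lightning as pl
from lightning_pose.callbacks import UnfreezeBackbone

callbacks = [
    UnfreezeBackbone(unfreeze_epoch=2),  # placeholder; normally cfg.training.unfreezing_epoch
    # Runs after UnfreezeBackbone within on_train_epoch_start, so the logged
    # backbone learning rate reflects the value the callback just assigned.
    pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
]
trainer = pl.Trainer(max_epochs=5, callbacks=callbacks)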
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -528,7 +528,7 @@ def video_dataloader(cfg, base_dataset, video_list) -> LitDaliWrapper:
def trainer(cfg) -> pl.Trainer:
"""Create a basic pytorch lightning trainer for testing models."""

cfg.training.unfreezing_epoch = 10 # force no unfreezing to keep memory reqs of tests down
cfg.training.unfreezing_epoch = 1 # exercise unfreezing
callbacks = get_callbacks(cfg, early_stopping=False, lr_monitor=False, backbone_unfreeze=True)

trainer = pl.Trainer(
29 changes: 29 additions & 0 deletions tests/test_callbacks.py
@@ -0,0 +1,29 @@
from lightning_pose.callbacks import UnfreezeBackbone


def test_unfreeze_backbone():
unfreeze_backbone = UnfreezeBackbone(2, initial_ratio=0.1, epoch_ratio=1.5)

# Test unfreezing at epoch 2.
assert unfreeze_backbone._get_backbone_lr(0, 1e-3) == 0.0
assert unfreeze_backbone._get_backbone_lr(1, 1e-3) == 0.0
assert (
unfreeze_backbone._get_backbone_lr(2, 1e-3) == 1e-3 * 0.1
) # upsampling_lr * initial_ratio

# Test warming up.
# We thawed at upsampling_lr = 1e-3. Henceforth, backbone_lr should be
# agnostic to changes in upsampling_lr so long as we are not fully
# "warmed up".
assert unfreeze_backbone._get_backbone_lr(3, 1e-3) == 1e-3 * 0.1 * 1.5
assert unfreeze_backbone._get_backbone_lr(3, 1.5e-3) == 1e-3 * 0.1 * 1.5

assert unfreeze_backbone._get_backbone_lr(4, 1e-3) == 1e-4 * 1.5 * 1.5
assert unfreeze_backbone._get_backbone_lr(4, 1.5e-3) == 1e-4 * 1.5 * 1.5

# Once we hit upsampling_lr, set the _warmed_up bit to stop this callback
# from setting backbone lr in the future, allowing the normal scheduler to take over.
assert not unfreeze_backbone._warmed_up
# current_epoch set to some high value to trigger "warmed up" condition
assert unfreeze_backbone._get_backbone_lr(15, 1e-3) == 1e-3
assert unfreeze_backbone._warmed_up