Change backbone finetuning strategy to allow for DDP
See Lightning-AI/pytorch-lightning#20340 for the original issue. The implemented workaround no longer freezes/unfreezes layers via `requires_grad`; instead, it only modulates the learning rate of the backbone layers.
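For illustration, here is a minimal sketch of the learning-rate schedule this strategy implies, reconstructed from the test expectations in this commit. The class name, the `initial_ratio=0.1` default, and the `warmup_factor=1.5` default are inferred assumptions, not the actual implementation of `lightning_pose.callbacks.UnfreezeBackbone`.

```python
# Sketch only: parameter names and defaults are inferred from the test below,
# not copied from the real UnfreezeBackbone callback.
class UnfreezeBackboneSketch:
    def __init__(self, unfreeze_epoch, initial_ratio=0.1, warmup_factor=1.5):
        self.unfreeze_epoch = unfreeze_epoch
        self.initial_ratio = initial_ratio
        self.warmup_factor = warmup_factor
        self._align_backbone_lr_with_upsampling_lr = False

    def _get_next_backbone_lr(self, current_epoch, backbone_lr, upsampling_lr):
        # Once the backbone lr has caught up, keep it pinned to the upsampling lr.
        if self._align_backbone_lr_with_upsampling_lr:
            return upsampling_lr
        # Before the unfreeze epoch the backbone is effectively frozen (lr = 0).
        if current_epoch < self.unfreeze_epoch:
            return 0.0
        # At the unfreeze epoch, start from a fraction of the upsampling lr.
        if current_epoch == self.unfreeze_epoch:
            return upsampling_lr * self.initial_ratio
        # Afterwards, warm up geometrically until reaching the upsampling lr.
        next_lr = backbone_lr * self.warmup_factor
        if next_lr >= upsampling_lr:
            self._align_backbone_lr_with_upsampling_lr = True
            return upsampling_lr
        return next_lr
```

Because the backbone parameters always participate in the backward pass (their learning rate is simply zero while "frozen"), DDP's gradient synchronization sees a fixed set of trainable parameters throughout training.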
Showing 6 changed files with 109 additions and 86 deletions.
@@ -0,0 +1,22 @@
from lightning_pose.callbacks import UnfreezeBackbone


def test_unfreeze_backbone():
    unfreeze_backbone = UnfreezeBackbone(2)

    # Test unfreezing at epoch 2.
    assert unfreeze_backbone._get_next_backbone_lr(0, 0.0, 1e-3) == 0.0
    assert unfreeze_backbone._get_next_backbone_lr(1, 0.0, 1e-3) == 0.0
    assert unfreeze_backbone._get_next_backbone_lr(2, 0.0, 1e-3) == 1e-4

    # Test warming up.
    assert unfreeze_backbone._get_next_backbone_lr(3, 1e-4, 1e-3) == 1e-4 * 1.5
    assert unfreeze_backbone._get_next_backbone_lr(4, 1e-4 * 1.5, 1e-3) == 1e-4 * 1.5 * 1.5

    # Test aligning with upsampling_lr.
    assert not unfreeze_backbone._align_backbone_lr_with_upsampling_lr
    assert unfreeze_backbone._get_next_backbone_lr(5, 0.9e-3, 1e-3) == 1e-3
    assert unfreeze_backbone._align_backbone_lr_with_upsampling_lr
    # From now on, for any backbone_lr it should always return upsampling_lr.
    assert unfreeze_backbone._get_next_backbone_lr(6, 1, 1e-2) == 1e-2
    assert unfreeze_backbone._get_next_backbone_lr(6, 1, 1e-3) == 1e-3
    assert unfreeze_backbone._get_next_backbone_lr(6, 1, 1e-4) == 1e-4
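For context, a hedged usage sketch showing how the callback would be combined with DDP; the trainer arguments, device count, and the commented-out `fit` call are placeholders, not part of this commit.

```python
import lightning.pytorch as pl
from lightning_pose.callbacks import UnfreezeBackbone

# Hypothetical usage: because the callback only adjusts the backbone's learning
# rate (rather than toggling requires_grad), it can be combined with DDP, where
# all ranks must agree on which parameters receive gradients.
trainer = pl.Trainer(
    max_epochs=100,
    strategy="ddp",
    devices=2,
    callbacks=[UnfreezeBackbone(2)],  # unfreeze the backbone at epoch 2
)
# trainer.fit(model, datamodule=datamodule)  # model/datamodule defined elsewhere
```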