Fix module dict in base finetuning (#8170)
* Fix module dict in base finetuning

* Update CHANGELOG.md
ethanwharris authored and lexierule committed Jul 1, 2021
1 parent 2856e62 commit d317ebf
Showing 2 changed files with 11 additions and 6 deletions.
5 changes: 4 additions & 1 deletion pytorch_lightning/callbacks/finetuning.py
@@ -20,7 +20,7 @@
 from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union
 
 import torch
-from torch.nn import Module
+from torch.nn import Module, ModuleDict
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.optim.optimizer import Optimizer
@@ -114,6 +114,9 @@ def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
         Returns:
             List of modules
         """
+        if isinstance(modules, ModuleDict):
+            modules = modules.values()
+
         if isinstance(modules, Iterable):
            _modules = []
            for m in modules:
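
Why the special case is needed (context, not part of the diff): iterating an nn.ModuleDict yields its string keys, just like a plain dict, so the recursive flatten_modules would otherwise try to descend into strings rather than submodules. A minimal sketch of that behavior, assuming only stock PyTorch:

import torch.nn as nn

md = nn.ModuleDict({"conv": nn.Conv2d(3, 8, 3), "act": nn.ReLU()})

# Iterating a ModuleDict yields its keys (strings), like a plain dict ...
print(list(md))  # ['conv', 'act']
# ... so flatten_modules has to switch to .values() to reach the modules.
print([type(m).__name__ for m in md.values()])  # ['Conv2d', 'ReLU']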
12 changes: 7 additions & 5 deletions tests/callbacks/test_finetuning_callback.py
@@ -330,15 +330,17 @@ class ConvBlockParam(nn.Module):
 
         def __init__(self, in_channels, out_channels):
             super().__init__()
-            self.conv = nn.Conv2d(in_channels, out_channels, 3)
-            self.act = nn.ReLU()
+            self.module_dict = nn.ModuleDict({
+                "conv": nn.Conv2d(in_channels, out_channels, 3),
+                "act": nn.ReLU(),
+            })
             # add trivial test parameter to convblock to validate parent (non-leaf) module parameter handling
             self.parent_param = nn.Parameter(torch.zeros((1), dtype=torch.float))
             self.bn = nn.BatchNorm2d(out_channels)
 
         def forward(self, x):
-            x = self.conv(x)
-            x = self.act(x)
+            x = self.module_dict["conv"](x)
+            x = self.module_dict["act"](x)
             return self.bn(x)
 
     model = nn.Sequential(
@@ -352,7 +354,7 @@ def forward(self, x):
     assert len(BaseFinetuning.flatten_modules(model)) == 10
 
     BaseFinetuning.freeze(model.encoder, train_bn=True)
-    assert not model.encoder[0].conv.weight.requires_grad  # Validate a leaf module parameter is frozen
+    assert not model.encoder[0].module_dict["conv"].weight.requires_grad  # Validate a leaf module parameter is frozen
     assert not model.encoder[0].parent_param.requires_grad  # Validate the parent module parameter is frozen
     assert model.encoder[0].bn.weight.requires_grad
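
With the fix, flatten_modules recurses into a ModuleDict's values rather than its keys. A hedged usage sketch (the model below is illustrative, not the test's; only flattening is exercised, so the ModuleDict never needs a forward):

import torch.nn as nn
from pytorch_lightning.callbacks.finetuning import BaseFinetuning

# Illustrative model nesting a ModuleDict inside a container module.
model = nn.Sequential(
    nn.ModuleDict({"conv": nn.Conv2d(3, 8, 3), "act": nn.ReLU()}),
    nn.BatchNorm2d(8),
)

leaves = BaseFinetuning.flatten_modules(model)
print([type(m).__name__ for m in leaves])  # expected: ['Conv2d', 'ReLU', 'BatchNorm2d']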

