Skip to content

Commit

Permalink
[Fix] Fix init weights in Swin and PVT. (#6663)
Browse the repository at this point in the history
  • Loading branch information
RangiLyu authored and ZwwWayne committed Dec 14, 2021
1 parent 8080c46 commit be19628
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
7 changes: 2 additions & 5 deletions mmdet/models/backbones/pvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,15 +529,12 @@ def init_weights(self):
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
constant_init(m, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[
1] * m.out_channels
fan_out //= m.groups
normal_init(m.weight, 0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
constant_init(m.bias, 0)
normal_init(m, 0, math.sqrt(2.0 / fan_out))
elif isinstance(m, AbsolutePositionEmbedding):
m.init_weights()
else:
Expand Down
3 changes: 1 addition & 2 deletions mmdet/models/backbones/swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,8 +678,7 @@ def init_weights(self):
if isinstance(m, nn.Linear):
trunc_normal_init(m, std=.02, bias=0.)
elif isinstance(m, nn.LayerNorm):
constant_init(m.bias, 0)
constant_init(m.weight, 1.0)
constant_init(m, 1.0)
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
Expand Down

0 comments on commit be19628

Please sign in to comment.