From be19628f32d236ca314ca0659ffafefac1e505c3 Mon Sep 17 00:00:00 2001 From: RangiLyu Date: Tue, 7 Dec 2021 19:56:44 +0800 Subject: [PATCH] [Fix] Fix init weights in Swin and PVT. (#6663) --- mmdet/models/backbones/pvt.py | 7 ++----- mmdet/models/backbones/swin.py | 3 +-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/mmdet/models/backbones/pvt.py b/mmdet/models/backbones/pvt.py index c5365c53c8e..1680dd69d3f 100644 --- a/mmdet/models/backbones/pvt.py +++ b/mmdet/models/backbones/pvt.py @@ -529,15 +529,12 @@ def init_weights(self): if isinstance(m, nn.Linear): trunc_normal_init(m, std=.02, bias=0.) elif isinstance(m, nn.LayerNorm): - constant_init(m.bias, 0) - constant_init(m.weight, 1.0) + constant_init(m, 1.0) elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[ 1] * m.out_channels fan_out //= m.groups - normal_init(m.weight, 0, math.sqrt(2.0 / fan_out)) - if m.bias is not None: - constant_init(m.bias, 0) + normal_init(m, 0, math.sqrt(2.0 / fan_out)) elif isinstance(m, AbsolutePositionEmbedding): m.init_weights() else: diff --git a/mmdet/models/backbones/swin.py b/mmdet/models/backbones/swin.py index 316aec9765f..96b95c0c32f 100644 --- a/mmdet/models/backbones/swin.py +++ b/mmdet/models/backbones/swin.py @@ -678,8 +678,7 @@ def init_weights(self): if isinstance(m, nn.Linear): trunc_normal_init(m, std=.02, bias=0.) elif isinstance(m, nn.LayerNorm): - constant_init(m.bias, 0) - constant_init(m.weight, 1.0) + constant_init(m, 1.0) else: assert 'checkpoint' in self.init_cfg, f'Only support ' \ f'specify `Pretrained` in ' \