[Refactor]: Unified parameter initialization #4750

Merged: 39 commits, Apr 28, 2021

Changes from 17 commits

Commits (39):
e9d2589  Support RetinaNet (hhaAndroid, Mar 11, 2021)
23dbf69  Update RetinaNet init_cfg (hhaAndroid, Mar 11, 2021)
96fa90f  Update RetinaNet init_cfg (hhaAndroid, Mar 11, 2021)
55a99e0  Update model (hhaAndroid, Mar 15, 2021)
2753c7c  Update all model (hhaAndroid, Mar 16, 2021)
a3b4bd5  Conflict Resolution (hhaAndroid, Mar 16, 2021)
60ab09e  Fix type error (hhaAndroid, Mar 16, 2021)
44c5916  Support specify init_cfg (hhaAndroid, Mar 17, 2021)
ae84822  Support override init_cfg (hhaAndroid, Mar 17, 2021)
6efd99f  Support ModuleList and Seq (hhaAndroid, Mar 18, 2021)
ed2d6b7  Use ModuleList (hhaAndroid, Mar 18, 2021)
2a69fcc  Update init_cfg (hhaAndroid, Mar 18, 2021)
59205ff  Add docstr and support Caffe2Xavier (hhaAndroid, Mar 22, 2021)
c732dfb  Conflict Resolution (hhaAndroid, Mar 22, 2021)
baa11cc  Update init_weight (hhaAndroid, Mar 24, 2021)
4741aac  Fix Sequential (hhaAndroid, Mar 25, 2021)
8118d4a  Fix regnet (qwe12369, Mar 29, 2021)
e1ad635  Fix BN init_cfg (hhaAndroid, Mar 31, 2021)
f57c7c7  Conflict resolution (hhaAndroid, Mar 31, 2021)
e4c010b  Fix error (hhaAndroid, Mar 31, 2021)
fddac29  Fix unittest (hhaAndroid, Mar 31, 2021)
c08fb55  Fix init error (hhaAndroid, Mar 31, 2021)
f627077  Fix bn name (hhaAndroid, Mar 31, 2021)
5f5b222  Fix resnet unittest (hhaAndroid, Mar 31, 2021)
d2ab217  Fix unittest (hhaAndroid, Apr 6, 2021)
48a7b9f  Fix ssd and yolact (hhaAndroid, Apr 7, 2021)
8aa3b3b  Fix layer error (hhaAndroid, Apr 8, 2021)
8d6eede  Fix point_rend (hhaAndroid, Apr 8, 2021)
cdb3ca4  Fix htc (hhaAndroid, Apr 13, 2021)
6370313  Rename init_weight to init_weights (hhaAndroid, Apr 20, 2021)
5dfbc21  merge master (hhaAndroid, Apr 21, 2021)
acef24b  delete mmcv link (hhaAndroid, Apr 21, 2021)
9d0c595  Fix assert error (hhaAndroid, Apr 22, 2021)
2180ec2  Fix ssd init (hhaAndroid, Apr 25, 2021)
486d6d2  Fix carafe init (hhaAndroid, Apr 25, 2021)
b6f7967  merge master (hhaAndroid, Apr 28, 2021)
aeab726  fix lint (hhaAndroid, Apr 28, 2021)
4a9fbd2  update mmcv version (hhaAndroid, Apr 28, 2021)
a11e276  merge master (hhaAndroid, Apr 28, 2021)
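
Taken together, the diffs below apply one migration pattern: each backbone stops subclassing nn.Module and hand-writing an init_weights(pretrained) method, and instead subclasses mmcv's BaseModule and declares its initialization through an init_cfg dict that BaseModule applies. A minimal before/after sketch of that pattern (the class and layer choices here are illustrative, not from this PR, and assume an mmcv version that ships BaseModule):

import torch.nn as nn
from mmcv.runner import BaseModule


class ToyBackbone(BaseModule):

    def __init__(self, init_cfg=None):
        if init_cfg is None:
            # The same defaults the old hand-written init_weights() applied:
            # Kaiming for convs, constant 1 for normalization layers.
            init_cfg = [
                dict(type='Kaiming', layer='Conv2d'),
                dict(type='Constant', val=1,
                     layer=['_BatchNorm', 'GroupNorm'])
            ]
        super(ToyBackbone, self).__init__(init_cfg)
        self.conv = nn.Conv2d(3, 16, 3)
        self.bn = nn.BatchNorm2d(16)


model = ToyBackbone()
model.init_weights()  # one call walks the module tree and applies init_cfg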
59 changes: 36 additions & 23 deletions mmdet/models/backbones/darknet.py
@@ -1,16 +1,16 @@
 # Copyright (c) 2019 Western Digital Corporation or its affiliates.

-import logging
+import warnings

 import torch.nn as nn
-from mmcv.cnn import ConvModule, constant_init, kaiming_init
-from mmcv.runner import load_checkpoint
+from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule
 from torch.nn.modules.batchnorm import _BatchNorm

 from ..builder import BACKBONES


-class ResBlock(nn.Module):
+class ResBlock(BaseModule):
     """The basic residual block used in Darknet. Each ResBlock consists of two
     ConvModules and the input is added to the final output. Each ConvModule is
     composed of Conv, BN, and LeakyReLU. In YoloV3 paper, the first convLayer
@@ -25,14 +25,17 @@ class ResBlock(nn.Module):
             Default: dict(type='BN', requires_grad=True)
         act_cfg (dict): Config dict for activation layer.
             Default: dict(type='LeakyReLU', negative_slope=0.1).
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
     """

     def __init__(self,
                  in_channels,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN', requires_grad=True),
-                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
-        super(ResBlock, self).__init__()
+                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+                 init_cfg=None):
+        super(ResBlock, self).__init__(init_cfg)
         assert in_channels % 2 == 0  # ensure the in_channels is even
         half_in_channels = in_channels // 2

@@ -53,7 +56,7 @@ def forward(self, x):


 @BACKBONES.register_module()
-class Darknet(nn.Module):
+class Darknet(BaseModule):
     """Darknet backbone.

     Args:
@@ -69,6 +72,9 @@ class Darknet(nn.Module):
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
             and its variants only.
+        pretrained (str, optional): model pretrained path. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None

     Example:
         >>> from mmdet.models import Darknet
@@ -98,10 +104,13 @@ def __init__(self,
                  conv_cfg=None,
                  norm_cfg=dict(type='BN', requires_grad=True),
                  act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
-                 norm_eval=True):
-        super(Darknet, self).__init__()
+                 norm_eval=True,
+                 pretrained=None,
+                 init_cfg=None):
+        super(Darknet, self).__init__(init_cfg)
         if depth not in self.arch_settings:
             raise KeyError(f'invalid depth {depth} for darknet')
+
         self.depth = depth
         self.out_indices = out_indices
         self.frozen_stages = frozen_stages
@@ -122,6 +131,24 @@ def __init__(self,

         self.norm_eval = norm_eval

+        assert not (init_cfg and pretrained), \
+            'init_cfg and pretrained cannot be setting at the same time'
+        if isinstance(pretrained, str):
+            warnings.warn('DeprecationWarning: pretrained is a deprecated, '
+                          'please use "init_cfg" instead')
+            self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
+        elif pretrained is None:
+            if init_cfg is None:
+                self.init_cfg = [
+                    dict(type='Kaiming', layer='Conv2d'),
+                    dict(
+                        type='Constant',
+                        val=1,
+                        layer=['_BatchNorm', 'GroupNorm'])
+                ]
+        else:
+            raise TypeError('pretrained must be a str or None')
+
     def forward(self, x):
         outs = []
         for i, layer_name in enumerate(self.cr_blocks):
@@ -132,20 +159,6 @@ def forward(self, x):

         return tuple(outs)

-    def init_weights(self, pretrained=None):
-        if isinstance(pretrained, str):
-            logger = logging.getLogger()
-            load_checkpoint(self, pretrained, strict=False, logger=logger)
-        elif pretrained is None:
-            for m in self.modules():
-                if isinstance(m, nn.Conv2d):
-                    kaiming_init(m)
-                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
-                    constant_init(m, 1)
-
-        else:
-            raise TypeError('pretrained must be a str or None')
-
     def _freeze_stages(self):
         if self.frozen_stages >= 0:
             for i in range(self.frozen_stages):
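
With the block added above, a deprecated pretrained string is converted into init_cfg=dict(type='Pretrained', checkpoint=...), so the checkpoint loading that the deleted init_weights() did with load_checkpoint() now goes through the 'Pretrained' initializer. A hedged usage sketch (the checkpoint path is hypothetical):

from mmdet.models import Darknet

# Explicit form: declare the checkpoint through init_cfg.
model = Darknet(
    depth=53,
    init_cfg=dict(type='Pretrained', checkpoint='path/to/darknet53.pth'))
model.init_weights()

# Legacy form: still accepted, but emits a deprecation warning and is
# rewritten internally into the init_cfg shown above.
model = Darknet(depth=53, pretrained='path/to/darknet53.pth')
model.init_weights()

Omitting both arguments falls back to the defaults set in __init__: Kaiming for Conv2d and constant 1 for _BatchNorm/GroupNorm, matching the deleted init_weights() behavior.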
23 changes: 11 additions & 12 deletions mmdet/models/backbones/detectors_resnet.py
@@ -1,6 +1,7 @@
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_conv_layer, build_norm_layer, constant_init
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from mmcv.runner import Sequential

 from ..builder import BACKBONES
 from .resnet import Bottleneck as _Bottleneck
@@ -22,6 +23,8 @@ class Bottleneck(_Bottleneck):
             added for ``rfp_feat``. Otherwise, the structure is the same as
             base class.
         sac (dict, optional): Dictionary to construct SAC. Default: None.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
     """
     expansion = 4

@@ -30,8 +33,10 @@ def __init__(self,
                  planes,
                  rfp_inplanes=None,
                  sac=None,
+                 init_cfg=None,
                  **kwargs):
-        super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+        super(Bottleneck, self).__init__(
+            inplanes, planes, init_cfg=init_cfg, **kwargs)

         assert sac is None or isinstance(sac, dict)
         self.sac = sac
@@ -56,12 +61,9 @@ def __init__(self,
                 1,
                 stride=1,
                 bias=True)
-            self.init_weights()
-
-    def init_weights(self):
-        """Initialize the weights."""
-        if self.rfp_inplanes:
-            constant_init(self.rfp_conv, 0)
+        if init_cfg is None:
+            self.init_cfg = dict(
+                type='Constant', val=0, override=dict(name='rfp_conv'))

     def rfp_forward(self, x, rfp_feat):
         """The forward function that also takes the RFP features as input."""
@@ -110,7 +112,7 @@ def _inner_forward(x):
         return out


-class ResLayer(nn.Sequential):
+class ResLayer(Sequential):
     """ResLayer to build ResNet style backbone for RPF in detectoRS.

     The difference between this module and base class is that we pass
@@ -216,7 +218,6 @@ class DetectoRS_ResNet(ResNet):
             base class.
         output_img (bool): If ``True``, the input image will be inserted into
             the starting position of output. Default: False.
-        pretrained (str, optional): The pretrained model to load.
     """

     arch_settings = {
@@ -230,13 +231,11 @@ def __init__(self,
                  stage_with_sac=(False, False, False, False),
                  rfp_inplanes=None,
                  output_img=False,
-                 pretrained=None,
                  **kwargs):
         self.sac = sac
         self.stage_with_sac = stage_with_sac
         self.rfp_inplanes = rfp_inplanes
         self.output_img = output_img
-        self.pretrained = pretrained
         super(DetectoRS_ResNet, self).__init__(**kwargs)

         self.inplanes = self.stem_channels
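
The Bottleneck change above is the first use of init_cfg's override key in these diffs: because the top-level dict names no layer, the Constant initializer touches nothing globally, and only the submodule attribute named in override is zero-initialized, reproducing the deleted constant_init(self.rfp_conv, 0). A self-contained sketch of those semantics (the class is made up; the behavior follows mmcv's initialize() as this PR uses it):

import torch.nn as nn
from mmcv.runner import BaseModule


class ToyRFP(BaseModule):

    def __init__(self):
        # Zero-init only the attribute named 'rfp_conv'; other layers are
        # left untouched because no top-level `layer` key is given.
        super(ToyRFP, self).__init__(
            init_cfg=dict(
                type='Constant', val=0, override=dict(name='rfp_conv')))
        self.conv = nn.Conv2d(8, 8, 3)      # keeps PyTorch's default init
        self.rfp_conv = nn.Conv2d(8, 8, 1)  # zeroed by the override


model = ToyRFP()
model.init_weights()
assert float(model.rfp_conv.weight.abs().sum()) == 0.0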
36 changes: 20 additions & 16 deletions mmdet/models/backbones/hourglass.py
@@ -1,12 +1,13 @@
 import torch.nn as nn
 from mmcv.cnn import ConvModule
+from mmcv.runner import BaseModule

 from ..builder import BACKBONES
 from ..utils import ResLayer
 from .resnet import BasicBlock


-class HourglassModule(nn.Module):
+class HourglassModule(BaseModule):
     """Hourglass Module for HourglassNet backbone.

     Generate module recursively and use BasicBlock as the base unit.
@@ -18,14 +19,17 @@ class HourglassModule(nn.Module):
         stage_blocks (list[int]): Number of sub-modules stacked in current and
             follow-up HourglassModule.
         norm_cfg (dict): Dictionary to construct and config norm layer.
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None
     """

     def __init__(self,
                  depth,
                  stage_channels,
                  stage_blocks,
-                 norm_cfg=dict(type='BN', requires_grad=True)):
-        super(HourglassModule, self).__init__()
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 init_cfg=None):
+        super(HourglassModule, self).__init__(init_cfg)

         self.depth = depth

@@ -78,7 +82,7 @@ def forward(self, x):


 @BACKBONES.register_module()
-class HourglassNet(nn.Module):
+class HourglassNet(BaseModule):
     """HourglassNet backbone.

     Stacked Hourglass Networks for Human Pose Estimation.
@@ -95,6 +99,9 @@ class HourglassNet(nn.Module):
             HourglassModule.
         feat_channel (int): Feature channel of conv after a HourglassModule.
         norm_cfg (dict): Dictionary to construct and config norm layer.
+        pretrained (str, optional): model pretrained path. Default: None
+        init_cfg (dict or list[dict], optional): Initialization config dict.
+            Default: None

     Example:
         >>> from mmdet.models import HourglassNet
@@ -115,8 +122,12 @@ def __init__(self,
                  stage_channels=(256, 256, 384, 384, 384, 512),
                  stage_blocks=(2, 2, 2, 2, 2, 4),
                  feat_channel=256,
-                 norm_cfg=dict(type='BN', requires_grad=True)):
-        super(HourglassNet, self).__init__()
+                 norm_cfg=dict(type='BN', requires_grad=True),
+                 pretrained=None,
+                 init_cfg=None):
+        assert init_cfg is None, 'To prevent abnormal initialization ' \
+            'behavior, init_cfg is not allowed to be set'
+        super(HourglassNet, self).__init__(init_cfg)

         self.num_stacks = num_stacks
         assert self.num_stacks >= 1
@@ -161,17 +172,10 @@ def __init__(self,

         self.relu = nn.ReLU(inplace=True)

-    def init_weights(self, pretrained=None):
-        """Init module weights.
-
-        We do nothing in this function because all modules we used
-        (ConvModule, BasicBlock and etc.) have default initialization, and
-        currently we don't provide pretrained model of HourglassNet.
-
-        Detector's __init__() will call backbone's init_weights() with
-        pretrained as input, so we keep this function.
-        """
+    def init_weight(self):
+        """Init module weights."""
+        # Training Centripetal Model needs to reset parameters for Conv2d
+        super(HourglassNet, self).init_weight()
     for m in self.modules():
         if isinstance(m, nn.Conv2d):
             m.reset_parameters()
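
HourglassNet takes the opposite route from Darknet: it forbids callers from passing init_cfg at all and extends the weight-initialization hook itself, first running the inherited pass and then resetting every Conv2d, which the code comment attributes to CentripetalNet training. Note the hook is still spelled init_weight at this 17-commit snapshot; commit 6370313 in this PR renames it to init_weights. A hedged sketch of the same extension pattern under the final name (the class is illustrative, not from this PR):

import torch.nn as nn
from mmcv.runner import BaseModule


class ToyHourglass(BaseModule):

    def __init__(self):
        super(ToyHourglass, self).__init__(init_cfg=None)
        self.conv = nn.Conv2d(3, 16, 3)

    def init_weights(self):
        # Run the declarative init_cfg pass first...
        super(ToyHourglass, self).init_weights()
        # ...then layer custom behavior on top: restore PyTorch's default
        # Conv2d initialization via reset_parameters().
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                m.reset_parameters()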