Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor]: Unified parameter initialization #4750

Merged
merged 39 commits into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
e9d2589
Support RetinaNet
hhaAndroid Mar 11, 2021
23dbf69
Update RetinaNet init_cfg
hhaAndroid Mar 11, 2021
96fa90f
Update RetinaNet init_cfg
hhaAndroid Mar 11, 2021
55a99e0
Update model
hhaAndroid Mar 15, 2021
2753c7c
Update all model
hhaAndroid Mar 16, 2021
a3b4bd5
Conflict Resolution
hhaAndroid Mar 16, 2021
60ab09e
Fix type error
hhaAndroid Mar 16, 2021
44c5916
Support specify init_cfg
hhaAndroid Mar 17, 2021
ae84822
Support override init_cfg
hhaAndroid Mar 17, 2021
6efd99f
Support ModuleList and Seq
hhaAndroid Mar 18, 2021
ed2d6b7
Use ModuleList
hhaAndroid Mar 18, 2021
2a69fcc
Update init_cfg
hhaAndroid Mar 18, 2021
59205ff
Add docstr and support Caffe2Xavier
hhaAndroid Mar 22, 2021
c732dfb
Conflict Resolution
hhaAndroid Mar 22, 2021
baa11cc
Update init_weight
hhaAndroid Mar 24, 2021
4741aac
Fix Sequential
hhaAndroid Mar 25, 2021
8118d4a
Fix regnet
qwe12369 Mar 29, 2021
e1ad635
Fix BN init_cfg
hhaAndroid Mar 31, 2021
f57c7c7
Conflict resolution
hhaAndroid Mar 31, 2021
e4c010b
Fix error
hhaAndroid Mar 31, 2021
fddac29
Fix unittest
hhaAndroid Mar 31, 2021
c08fb55
Fix init error
hhaAndroid Mar 31, 2021
f627077
Fix bn name
hhaAndroid Mar 31, 2021
5f5b222
Fix resnet unittest
hhaAndroid Mar 31, 2021
d2ab217
Fix unittest
hhaAndroid Apr 6, 2021
48a7b9f
Fix ssd and yolact
hhaAndroid Apr 7, 2021
8aa3b3b
Fix layer error
hhaAndroid Apr 8, 2021
8d6eede
Fix point_rend
hhaAndroid Apr 8, 2021
cdb3ca4
Fix htc
hhaAndroid Apr 13, 2021
6370313
Rename init_weight to init_weights
hhaAndroid Apr 20, 2021
5dfbc21
merge master
hhaAndroid Apr 21, 2021
acef24b
delete mmcv link
hhaAndroid Apr 21, 2021
9d0c595
Fix assert error
hhaAndroid Apr 22, 2021
2180ec2
Fix ssd init
hhaAndroid Apr 25, 2021
486d6d2
Fix carafe init
hhaAndroid Apr 25, 2021
b6f7967
merge master
hhaAndroid Apr 28, 2021
aeab726
fix lint
hhaAndroid Apr 28, 2021
4a9fbd2
update mmcv version
hhaAndroid Apr 28, 2021
a11e276
merge master
hhaAndroid Apr 28, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 32 additions & 23 deletions mmdet/models/backbones/darknet.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright (c) 2019 Western Digital Corporation or its affiliates.

import logging
import warnings

import torch.nn as nn
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from ..builder import BACKBONES


class ResBlock(nn.Module):
class ResBlock(BaseModule):
"""The basic residual block used in Darknet. Each ResBlock consists of two
ConvModules and the input is added to the final output. Each ConvModule is
composed of Conv, BN, and LeakyReLU. In YoloV3 paper, the first convLayer
Expand All @@ -31,8 +31,9 @@ def __init__(self,
in_channels,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
super(ResBlock, self).__init__()
act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
init_cfg=None):
super(ResBlock, self).__init__(init_cfg)
assert in_channels % 2 == 0 # ensure the in_channels is even
half_in_channels = in_channels // 2

Expand All @@ -53,7 +54,7 @@ def forward(self, x):


@BACKBONES.register_module()
class Darknet(nn.Module):
class Darknet(BaseModule):
"""Darknet backbone.

Args:
Expand Down Expand Up @@ -98,10 +99,16 @@ def __init__(self,
conv_cfg=None,
norm_cfg=dict(type='BN', requires_grad=True),
act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
norm_eval=True):
super(Darknet, self).__init__()
norm_eval=True,
pretrained=None,
init_cfg=None):
super(Darknet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalid depth {depth} for darknet')

assert not (init_cfg and pretrained), \
    'init_cfg and pretrained cannot be setting at the same time'

self.depth = depth
self.out_indices = out_indices
self.frozen_stages = frozen_stages
Expand All @@ -122,6 +129,22 @@ def __init__(self,

self.norm_eval = norm_eval

if isinstance(pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated '
              'key, please consider using init_cfg')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
outs = []
for i, layer_name in enumerate(self.cr_blocks):
Expand All @@ -132,20 +155,6 @@ def forward(self, x):

return tuple(outs)

def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

else:
raise TypeError('pretrained must be a str or None')

def _freeze_stages(self):
if self.frozen_stages >= 0:
for i in range(self.frozen_stages):
Expand Down
12 changes: 3 additions & 9 deletions mmdet/models/backbones/detectors_resnet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_conv_layer, build_norm_layer, constant_init
from mmcv.cnn import build_conv_layer, build_norm_layer

from ..builder import BACKBONES
from .resnet import Bottleneck as _Bottleneck
Expand Down Expand Up @@ -56,12 +56,8 @@ def __init__(self,
1,
stride=1,
bias=True)
self.init_weights()

def init_weights(self):
"""Initialize the weights."""
if self.rfp_inplanes:
constant_init(self.rfp_conv, 0)
self.init_cfg = dict(
    override=dict(type='Constant', val=0, name='rfp_conv'))

def rfp_forward(self, x, rfp_feat):
"""The forward function that also takes the RFP features as input."""
Expand Down Expand Up @@ -230,13 +226,11 @@ def __init__(self,
stage_with_sac=(False, False, False, False),
rfp_inplanes=None,
output_img=False,
pretrained=None,
**kwargs):
self.sac = sac
self.stage_with_sac = stage_with_sac
self.rfp_inplanes = rfp_inplanes
self.output_img = output_img
self.pretrained = pretrained
super(DetectoRS_ResNet, self).__init__(**kwargs)

self.inplanes = self.stem_channels
Expand Down
18 changes: 11 additions & 7 deletions mmdet/models/backbones/hourglass.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from ..builder import BACKBONES
from ..utils import ResLayer
from .resnet import BasicBlock


class HourglassModule(nn.Module):
class HourglassModule(BaseModule):
"""Hourglass Module for HourglassNet backbone.

Generate module recursively and use BasicBlock as the base unit.
Expand All @@ -24,8 +25,9 @@ def __init__(self,
depth,
stage_channels,
stage_blocks,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HourglassModule, self).__init__()
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=None):
super(HourglassModule, self).__init__(init_cfg)

self.depth = depth

Expand Down Expand Up @@ -78,7 +80,7 @@ def forward(self, x):


@BACKBONES.register_module()
class HourglassNet(nn.Module):
class HourglassNet(BaseModule):
"""HourglassNet backbone.

Stacked Hourglass Networks for Human Pose Estimation.
Expand Down Expand Up @@ -115,8 +117,9 @@ def __init__(self,
stage_channels=(256, 256, 384, 384, 384, 512),
stage_blocks=(2, 2, 2, 2, 2, 4),
feat_channel=256,
norm_cfg=dict(type='BN', requires_grad=True)):
super(HourglassNet, self).__init__()
norm_cfg=dict(type='BN', requires_grad=True),
init_cfg=None):
super(HourglassNet, self).__init__(init_cfg)

self.num_stacks = num_stacks
assert self.num_stacks >= 1
Expand Down Expand Up @@ -161,7 +164,8 @@ def __init__(self,

self.relu = nn.ReLU(inplace=True)

def init_weights(self, pretrained=None):
# TODO: How to convert to init_cfg
def init_weight(self):
"""Init module weights.

We do nothing in this function because all modules we used
Expand Down
81 changes: 42 additions & 39 deletions mmdet/models/backbones/hrnet.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import warnings

import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
kaiming_init)
from mmcv.runner import load_checkpoint
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm

from mmdet.utils import get_root_logger
from ..builder import BACKBONES
from .resnet import BasicBlock, Bottleneck


class HRModule(nn.Module):
class HRModule(BaseModule):
"""High-Resolution Module for HRNet.

In this module, every branch has 4 BasicBlocks/Bottlenecks. Fusion/Exchange
Expand All @@ -25,8 +25,9 @@ def __init__(self,
multiscale_output=True,
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN')):
super(HRModule, self).__init__()
norm_cfg=dict(type='BN'),
init_cfg=None):
super(HRModule, self).__init__(init_cfg)
self._check_branches(num_branches, num_blocks, in_channels,
num_channels)

Expand All @@ -46,17 +47,17 @@ def _check_branches(self, num_branches, num_blocks, in_channels,
num_channels):
if num_branches != len(num_blocks):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_BLOCKS({len(num_blocks)})'
f'!= NUM_BLOCKS({len(num_blocks)})'
raise ValueError(error_msg)

if num_branches != len(num_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_CHANNELS({len(num_channels)})'
f'!= NUM_CHANNELS({len(num_channels)})'
raise ValueError(error_msg)

if num_branches != len(in_channels):
error_msg = f'NUM_BRANCHES({num_branches}) ' \
f'!= NUM_INCHANNELS({len(in_channels)})'
f'!= NUM_INCHANNELS({len(in_channels)})'
raise ValueError(error_msg)

def _make_one_branch(self,
Expand Down Expand Up @@ -195,7 +196,7 @@ def forward(self, x):


@BACKBONES.register_module()
class HRNet(nn.Module):
class HRNet(BaseModule):
"""HRNet backbone.

High-Resolution Representations for Labeling Pixels and Regions
Expand Down Expand Up @@ -263,8 +264,10 @@ def __init__(self,
norm_cfg=dict(type='BN'),
norm_eval=True,
with_cp=False,
zero_init_residual=False):
super(HRNet, self).__init__()
zero_init_residual=False,
pretrained=None,
init_cfg=None):
super(HRNet, self).__init__(init_cfg)
self.extra = extra
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
Expand Down Expand Up @@ -344,6 +347,32 @@ def __init__(self,
self.stage4, pre_stage_channels = self._make_stage(
self.stage4_cfg, num_channels)

if isinstance(self.pretrained, str):
warnings.warn('DeprecationWarning: pretrained is a deprecated '
'key, please consider using init_cfg')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)
elif pretrained is None:
if init_cfg is None:
self.init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
val=1,
layer=['_BatchNorm', 'GroupNorm'])
]
elif not isinstance(init_cfg, list):
self.init_cfg = [init_cfg]

if self.zero_init_residual:
self.init_cfg += [
dict(
type='Constant',
layer=['BatchNorm2', 'BatchNorm3'],
val=0),
]
else:
raise TypeError('pretrained must be a str or None')

@property
def norm1(self):
"""nn.Module: the normalization layer named "norm1" """
Expand Down Expand Up @@ -464,32 +493,6 @@ def _make_stage(self, layer_config, in_channels, multiscale_output=True):

return nn.Sequential(*hr_modules), in_channels

def init_weights(self, pretrained=None):
"""Initialize the weights in backbone.

Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if isinstance(pretrained, str):
logger = get_root_logger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)

if self.zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
constant_init(m.norm3, 0)
elif isinstance(m, BasicBlock):
constant_init(m.norm2, 0)
else:
raise TypeError('pretrained must be a str or None')

def forward(self, x):
"""Forward function."""
x = self.conv1(x)
Expand Down
4 changes: 3 additions & 1 deletion mmdet/models/backbones/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ def __init__(self,
stage_with_dcn=(False, False, False, False),
plugins=None,
with_cp=False,
zero_init_residual=True):
zero_init_residual=True,
pretrained=None,
init_cfg=None):
super(ResNet, self).__init__()

# Generate RegNet parameters first
Expand Down
Loading