-
Notifications
You must be signed in to change notification settings - Fork 552
/
kd_tiny_rtmdet_s_neck_300e_coco.py
99 lines (94 loc) · 4.01 KB
/
kd_tiny_rtmdet_s_neck_300e_coco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
_base_ = '../rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py'

# Checkpoint URL handed to the model below as `teacher_ckpt`.
teacher_ckpt = 'https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco/rtmdet_s_syncbn_fast_8xb32-300e_coco_20221230_182329-0a8c901a.pth'  # noqa: E501

# BN settings (no affine parameters, no running statistics) shared by every
# connector defined below.
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

# The three per-level entries (fpn0, fpn1, fpn2) are identical apart from
# their index, so they are generated with comprehensions. `range(3)` is kept
# inline on purpose: a helper variable would add a new module-level name to
# this file.
model = dict(
    _delete_=True,
    _scope_='mmrazor',
    type='FpnTeacherDistill',
    architecture=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py'),
    teacher=dict(
        cfg_path='mmyolo::rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py'),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        # Recorders capture intermediate results during the model forward;
        # student and teacher record the same three neck output modules.
        student_recorders={
            f'fpn{lvl}': dict(
                type='ModuleOutputs',
                source=f'neck.out_layers.{lvl}.conv')
            for lvl in range(3)
        },
        teacher_recorders={
            f'fpn{lvl}': dict(
                type='ModuleOutputs',
                source=f'neck.out_layers.{lvl}.conv')
            for lvl in range(3)
        },
        # Connectors are adaptive layers that usually bring the student's
        # and the teacher's features to the same dimension. Per level, the
        # student entry (`_s`) is emitted before the teacher entry (`_t`).
        connectors={
            name: cfg
            for lvl in range(3)
            for name, cfg in (
                (f'fpn{lvl}_s',
                 dict(
                     type='ConvModuleConnector',
                     in_channel=96,
                     out_channel=128,
                     bias=False,
                     norm_cfg=norm_cfg,
                     act_cfg=None)),
                (f'fpn{lvl}_t',
                 dict(
                     type='NormConnector',
                     in_channels=128,
                     norm_cfg=norm_cfg)),
            )
        },
        # One distillation loss per recorded level.
        distill_losses={
            f'loss_fpn{lvl}': dict(
                type='ChannelWiseDivergence', loss_weight=1)
            for lvl in range(3)
        },
        # Maps each loss's forward arguments to the record (and the
        # connector applied to it) they are fed from.
        loss_forward_mappings={
            f'loss_fpn{lvl}': dict(
                preds_S=dict(
                    from_student=True,
                    recorder=f'fpn{lvl}',
                    connector=f'fpn{lvl}_s'),
                preds_T=dict(
                    from_student=False,
                    recorder=f'fpn{lvl}',
                    connector=f'fpn{lvl}_t'))
            for lvl in range(3)
        }))

find_unused_parameters = True
# Hooks for this distillation run, listed in their original order.
custom_hooks = [
    # EMA hook configured with the `ExpMomentumEMA` averaging type.
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        strict_load=False,
        priority=49),
    # Switches to the stage-2 training pipeline; the switch epoch is the
    # base config's `max_epochs` minus its `num_epochs_stage2`.
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=_base_.max_epochs - _base_.num_epochs_stage2,
        switch_pipeline=_base_.train_pipeline_stage2),
    # Distillation is stopped after the 280th epoch.
    dict(
        type='mmrazor.StopDistillHook',
        stop_epoch=280),
]