Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unify model initialization #235

Merged
merged 11 commits into from
Aug 10, 2021
Prev Previous commit
Next Next commit
unify model initialization for the rest code
GT9505 committed Aug 8, 2021

Verified

This commit was signed with the committer’s verified signature.
mbklein Michael B. Klein
commit 4f84fc6b54b7da10aaf8ab2695567d7420de5ade
5 changes: 3 additions & 2 deletions configs/_base_/models/cascade_mask_rcnn_r50_fpn.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
model = dict(
detector=dict(
type='CascadeRCNN',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
@@ -11,7 +10,9 @@
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
style='pytorch',
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
5 changes: 3 additions & 2 deletions configs/_base_/models/cascade_rcnn_r50_fpn.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
model = dict(
detector=dict(
type='CascadeRCNN',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
@@ -11,7 +10,9 @@
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
style='pytorch',
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
6 changes: 4 additions & 2 deletions configs/_base_/models/faster_rcnn_r50_caffe_c4.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,6 @@
model = dict(
detector=dict(
type='FasterRCNN',
pretrained='open-mmlab://detectron2/resnet50_caffe',
backbone=dict(
type='ResNet',
depth=50,
@@ -14,7 +13,10 @@
frozen_stages=1,
norm_cfg=norm_cfg,
norm_eval=True,
style='caffe'),
style='caffe',
init_cfg=dict(
type='Pretrained',
checkpoint='open-mmlab://detectron2/resnet50_caffe')),
rpn_head=dict(
type='RPNHead',
in_channels=1024,
5 changes: 3 additions & 2 deletions configs/_base_/models/retinanet_r50_fpn.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,6 @@
model = dict(
detector=dict(
type='RetinaNet',
pretrained='torchvision://resnet50',
backbone=dict(
type='ResNet',
depth=50,
@@ -11,7 +10,9 @@
frozen_stages=1,
norm_cfg=dict(type='BN', requires_grad=True),
norm_eval=True,
style='pytorch'),
style='pytorch',
init_cfg=dict(
type='Pretrained', checkpoint='torchvision://resnet50')),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
_base_ = [
'../mot/tracktor/tracktor_faster-rcnn_r50_fpn_4e_mot17-private-half.py'
]

model = dict(
pretrains=dict(
detector= # noqa: E251
'https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436-f4ba7d61.pth', # noqa: E501
reid= # noqa: E251
'https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055-4747ee95.pth' # noqa: E501
))
detector=dict(
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmtracking/fp16/faster-rcnn_r50_fpn_fp16_4e_mot17-half_20210730_002436-f4ba7d61.pth' # noqa: E501
)),
reid=dict(
init_cfg=dict(
type='Pretrained',
checkpoint= # noqa: E251
'https://download.openmmlab.com/mmtracking/fp16/reid_r50_fp16_8x32_6e_mot17_20210731_033055-4747ee95.pth' # noqa: E501
)))
fp16 = dict(loss_scale=512.)
10 changes: 6 additions & 4 deletions mmtrack/models/mot/trackers/base_tracker.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F
from addict import Dict
from mmcv.runner import BaseModule

from mmtrack.models import TRACKERS


@TRACKERS.register_module()
class BaseTracker(nn.Module, metaclass=ABCMeta):
class BaseTracker(BaseModule, metaclass=ABCMeta):
"""Base tracker model.

Args:
@@ -18,10 +18,12 @@ class BaseTracker(nn.Module, metaclass=ABCMeta):
indicates the momentum. Default to None.
num_frames_retain (int, optional). If a track is disappeared more than
`num_frames_retain` frames, it will be deleted in the memo.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""

def __init__(self, momentums=None, num_frames_retain=10):
super().__init__()
def __init__(self, momentums=None, num_frames_retain=10, init_cfg=None):
super().__init__(init_cfg)
if momentums is not None:
assert isinstance(momentums, dict), 'momentums must be a dict'
self.momentums = momentums
5 changes: 4 additions & 1 deletion mmtrack/models/mot/trackers/sort_tracker.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,8 @@ class SortTracker(BaseTracker):
Defaults to 0.7.
num_tentatives (int, optional): Number of continuous frames to confirm
a track. Defaults to 3.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""

def __init__(self,
@@ -41,8 +43,9 @@ def __init__(self,
match_score_thr=2.0),
match_iou_thr=0.7,
num_tentatives=3,
init_cfg=None,
**kwargs):
super().__init__(**kwargs)
super().__init__(init_cfg=init_cfg, **kwargs)
self.obj_score_thr = obj_score_thr
self.reid = reid
self.match_iou_thr = match_iou_thr
5 changes: 4 additions & 1 deletion mmtrack/models/mot/trackers/tracktor_tracker.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,8 @@ class TracktorTracker(BaseTracker):
matching process. Default to 2.0.
- match_iou_thr (float, optional): Minimum IoU when matching
objects with embedding similarity. Default to 0.2.
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""

def __init__(self,
@@ -49,8 +51,9 @@ def __init__(self,
img_norm_cfg=None,
match_score_thr=2.0,
match_iou_thr=0.2),
init_cfg=None,
**kwargs):
super().__init__(**kwargs)
super().__init__(init_cfg=init_cfg, **kwargs)
self.obj_score_thr = obj_score_thr
self.regression = regression
self.reid = reid
2 changes: 2 additions & 0 deletions tools/benchmark.py
Original file line number Diff line number Diff line change
@@ -57,6 +57,8 @@ def main():

# build the model and load checkpoint
model = build_model(cfg.model)
# We need call `init_weights()` to load pretained weights in MOT task.
model.init_weights()
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)
2 changes: 2 additions & 0 deletions tools/mot_param_search.py
Original file line number Diff line number Diff line change
@@ -156,6 +156,8 @@ def main():
cfg.model, train_cfg=cfg.train_cfg, test_cfg=cfg.test_cfg)
else:
model = build_model(cfg.model)
# We need call `init_weights()` to load pretained weights in MOT task.
model.init_weights()
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is not None:
wrap_fp16_model(model)