Skip to content

Commit

Permalink
[Feature]support pointpillar nus version (open-mmlab#391)
Browse files Browse the repository at this point in the history
* support pointpillar nus version

* support pointpillar nus version

* add regression test config for mmdet3d

* fix exit with no error code

* fix cfg

* fix worksize

* fix worksize

* fix cfg

* support nus pp

* fix yaml

* fix yaml

* fix yaml

* add ut

* fix ut

Co-authored-by: RunningLeon <[email protected]>
  • Loading branch information
VVsssssk and RunningLeon authored Aug 5, 2022
1 parent 80d24fc commit f957284
Show file tree
Hide file tree
Showing 14 changed files with 319 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/openvino.py']

onnx_config = dict(input_shape=None)

backend_config = dict(model_inputs=[
dict(
opt_shapes=dict(
voxels=[20000, 64, 4], num_points=[20000], coors=[20000, 4]))
])
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/tensorrt.py']
backend_config = dict(
common_config=dict(max_workspace_size=1 << 30),
common_config=dict(max_workspace_size=1 << 32),
model_inputs=[
dict(
input_shapes=dict(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
_base_ = ['./voxel-detection_dynamic.py', '../../_base_/backends/tensorrt.py']
backend_config = dict(
common_config=dict(max_workspace_size=1 << 32),
model_inputs=[
dict(
input_shapes=dict(
voxels=dict(
min_shape=[5000, 64, 4],
opt_shape=[20000, 64, 4],
max_shape=[30000, 64, 4]),
num_points=dict(
min_shape=[5000], opt_shape=[20000], max_shape=[30000]),
coors=dict(
min_shape=[5000, 4],
opt_shape=[20000, 4],
max_shape=[30000, 4]),
))
])
4 changes: 3 additions & 1 deletion mmdeploy/codebase/mmdet3d/deploy/voxel_detection_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def post_process(model_cfg: Union[str, mmcv.Config],
else:
raise NotImplementedError('Not supported model.')
head_cfg['train_cfg'] = None
head_cfg['test_cfg'] = model_cfg.model['test_cfg']
head_cfg['test_cfg'] = model_cfg.model['test_cfg']\
if 'pts' not in model_cfg.model['test_cfg'].keys()\
else model_cfg.model['test_cfg']['pts']
head = build_head(head_cfg)
if device == 'cpu':
logger = get_root_logger()
Expand Down
3 changes: 2 additions & 1 deletion mmdeploy/codebase/mmdet3d/models/centerpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def centerpoint__get_bbox(ctx,
scores_range = [0]
bbox_range = [0]
dir_range = [0]
self.test_cfg = self.test_cfg['pts']
for i, task_head in enumerate(self.task_heads):
scores_range.append(scores_range[i] + self.num_classes[i])
bbox_range.append(bbox_range[i] + 8)
Expand Down Expand Up @@ -135,6 +134,8 @@ def centerpoint__get_bbox(ctx,
batch_vel,
reg=batch_reg,
task_id=task_id)
if 'pts' in self.test_cfg.keys():
self.test_cfg = self.test_cfg.pts
assert self.test_cfg['nms_type'] in ['circle', 'rotate']
batch_reg_preds = [box['bboxes'] for box in temp]
batch_cls_preds = [box['scores'] for box in temp]
Expand Down
52 changes: 52 additions & 0 deletions mmdeploy/codebase/mmdet3d/models/mvx_two_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,55 @@ def mvxtwostagedetector__extract_feat(ctx, self, voxels, num_points, coors,
pts_feats = self.extract_pts_feat(voxels, num_points, coors, img_feats,
img_metas)
return (img_feats, pts_feats)


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.'
'extract_pts_feat')
def mvxtwostagedetector__extract_pts_feat(ctx, self, voxels, num_points, coors,
img_feats, img_metas):
"""Extract features from points. Rewrite this func to remove voxelize op.
Args:
voxels (torch.Tensor): Point features or raw points in shape (N, M, C).
num_points (torch.Tensor): Number of points in each voxel.
coors (torch.Tensor): Coordinates of each voxel.
img_feats (list[torch.Tensor], optional): Image features used for
multi-modality fusion. Defaults to None.
img_metas (list[dict]): Meta information of samples.
Returns:
torch.Tensor: Points feature.
"""
if not self.with_pts_bbox:
return None
voxel_features = self.pts_voxel_encoder(voxels, num_points, coors,
img_feats, img_metas)
batch_size = coors[-1, 0] + 1
x = self.pts_middle_encoder(voxel_features, coors, batch_size)
x = self.pts_backbone(x)
if self.with_pts_neck:
x = self.pts_neck(x)
return x


@FUNCTION_REWRITER.register_rewriter(
'mmdet3d.models.detectors.mvx_two_stage.MVXTwoStageDetector.'
'simple_test_pts')
def mvxtwostagedetector__simple_test_pts(ctx,
self,
x,
img_metas,
rescale=False):
"""Rewrite this func to format model outputs.
Args:
x (torch.Tensor): Input points feature.
img_metas (list[dict]): Meta information of samples.
rescale (bool): Whether need rescale.
Returns:
List: Result of model.
"""
bbox_preds, scores, dir_scores = self.pts_bbox_head(x)
return bbox_preds, scores, dir_scores
Binary file not shown.
106 changes: 106 additions & 0 deletions tests/regression/mmdet3d.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
globals:
codebase_dir: ../mmdetection3d
checkpoint_force_download: False
images:
kitti_input: &kitti_input ../mmdetection3d/demo/data/kitti/kitti_000008.bin
nus_input: &nus_input ./tests/data/n008-2018-08-01-15-16-36-0400__LIDAR_TOP__1533151612397179.pcd.bin

metric_info: &metric_info
AP: # named after metafile.Results.Metrics
eval_name: bbox # test.py --metrics args
metric_key: bbox_mAP # eval OrderedDict key name
tolerance: 1 # metric ±n%
task_name: 3D Object Detection # metafile.Results.Task
dataset: KITTI # metafile.Results.Dataset
mAP:
eval_name: bbox
metric_key: bbox_mAP
tolerance: 1 # metric ±n%
task_name: 3D Object Detection
dataset: nuScenes
NDS:
eval_name: bbox
metric_key: bbox_mAP
tolerance: 1 # metric ±n%
task_name: 3D Object Detection
dataset: nuScenes
backend_test: &default_backend_test False

convert_image: &convert_image
input_img: *kitti_input
test_img: *kitti_input


convert_image_nus: &convert_image_nus
input_img: *nus_input
test_img: *nus_input

onnxruntime:
pipeline_ort_dynamic_kitti_fp32: &pipeline_ort_dynamic_kitti_fp32
convert_image: *convert_image
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py

pipeline_ort_dynamic_nus_fp32: &pipeline_ort_dynamic_nus_fp32
convert_image: *convert_image_nus
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_onnxruntime_dynamic.py

tensorrt:
pipeline_trt_dynamic_nus_fp32_64x4: &pipeline_trt_dynamic_nus_fp32_64x4
convert_image: *convert_image_nus
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-nus-64x4.py

pipeline_trt_dynamic_nus_fp32_20x5: &pipeline_trt_dynamic_nus_fp32_20x5
convert_image: *convert_image_nus
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-nus-20x5.py

pipeline_trt_dynamic_kitti_fp32: &pipeline_trt_dynamic_kitti_fp32
convert_image: *convert_image
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_tensorrt_dynamic-kitti-32x4.py

openvino:
pipeline_openvino_dynamic_kitti_fp32: &pipeline_openvino_dynamic_kitti_fp32
convert_image: *convert_image
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_openvino_dynamic-kitti-32x4.py

pipeline_openvino_dynamic_nus_fp32_64x4: &pipeline_openvino_dynamic_nus_fp32_64x4
convert_image: *convert_image_nus
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_openvino_dynamic-nus-64x4.py

pipeline_openvino_dynamic_nus_fp32_20x5: &pipeline_openvino_dynamic_nus_fp32_20x5
convert_image: *convert_image_nus
backend_test: *default_backend_test
deploy_config: configs/mmdet3d/voxel-detection/voxel-detection_openvino_dynamic-nus-20x5.py

models:
- name: PointPillars
metafile: configs/pointpillars/metafile.yml
model_configs:
- configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-3class.py
- configs/pointpillars/hv_pointpillars_secfpn_6x8_160e_kitti-3d-car.py
pipelines:
- *pipeline_ort_dynamic_kitti_fp32
- *pipeline_openvino_dynamic_kitti_fp32
- *pipeline_trt_dynamic_kitti_fp32
- name: PointPillars
metafile: configs/pointpillars/metafile.yml
model_configs:
- configs/pointpillars/hv_pointpillars_secfpn_sbn-all_4x8_2x_nus-3d.py
pipelines:
- *pipeline_ort_dynamic_nus_fp32
- *pipeline_openvino_dynamic_nus_fp32_64x4
- *pipeline_trt_dynamic_nus_fp32_64x4
- name: CenterPoint
metafile: configs/centerpoint/metafile.yml
model_configs:
- configs/centerpoint/centerpoint_02pillar_second_secfpn_circlenms_4x8_cyclic_20e_nus.py
pipelines:
- *pipeline_ort_dynamic_nus_fp32
- *pipeline_openvino_dynamic_nus_fp32_20x5
- *pipeline_trt_dynamic_nus_fp32_20x5
89 changes: 89 additions & 0 deletions tests/test_codebase/test_mmdet3d/data/model_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,92 @@
pre_max_size=1000,
post_max_size=83,
nms_thr=0.2)))
voxel_size = [0.25, 0.25, 8]
pointpillars_nus_model = dict(
pts_voxel_layer=dict(
max_num_points=64,
point_cloud_range=[-50, -50, -5, 50, 50, 3],
voxel_size=voxel_size,
max_voxels=(30000, 40000)),
pts_voxel_encoder=dict(
type='HardVFE',
in_channels=4,
feat_channels=[64, 64],
with_distance=False,
voxel_size=voxel_size,
with_cluster_center=True,
with_voxel_center=True,
point_cloud_range=[-50, -50, -5, 50, 50, 3],
norm_cfg=dict(type='naiveSyncBN1d', eps=1e-3, momentum=0.01)),
pts_middle_encoder=dict(
type='PointPillarsScatter', in_channels=64, output_shape=[400, 400]),
pts_backbone=dict(
type='SECOND',
in_channels=64,
norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
layer_nums=[3, 5, 5],
layer_strides=[2, 2, 2],
out_channels=[64, 128, 256]),
pts_neck=dict(
type='FPN',
norm_cfg=dict(type='naiveSyncBN2d', eps=1e-3, momentum=0.01),
act_cfg=dict(type='ReLU'),
in_channels=[64, 128, 256],
out_channels=256,
start_level=0,
num_outs=3),
pts_bbox_head=dict(
type='Anchor3DHead',
num_classes=10,
in_channels=256,
feat_channels=256,
use_direction_classifier=True,
anchor_generator=dict(
type='AlignedAnchor3DRangeGenerator',
ranges=[[-50, -50, -1.8, 50, 50, -1.8]],
scales=[1, 2, 4],
sizes=[
[2.5981, 0.8660, 1.], # 1.5 / sqrt(3)
[1.7321, 0.5774, 1.], # 1 / sqrt(3)
[1., 1., 1.],
[0.4, 0.4, 1],
],
custom_values=[0, 0],
rotations=[0, 1.57],
reshape_out=True),
assigner_per_size=False,
diff_rad_by_sin=True,
dir_offset=-0.7854, # -pi / 4
bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder', code_size=9),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
loss_dir=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.2)),
# model training and testing settings
train_cfg=dict(
pts=dict(
assigner=dict(
type='MaxIoUAssigner',
iou_calculator=dict(type='BboxOverlapsNearest3D'),
pos_iou_thr=0.6,
neg_iou_thr=0.3,
min_pos_iou=0.3,
ignore_iof_thr=-1),
allowed_border=0,
code_weight=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2],
pos_weight=-1,
debug=False)),
test_cfg=dict(
pts=dict(
use_rotate_nms=True,
nms_across_levels=False,
nms_pre=1000,
nms_thr=0.2,
score_thr=0.05,
min_bbox_size=0,
max_num=500)))
37 changes: 37 additions & 0 deletions tests/test_codebase/test_mmdet3d/test_mmdet3d_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,3 +158,40 @@ def test_centerpoint(backend_type: Backend):
rewrite_outputs = head.get_bboxes(*[[i] for i in outputs],
inputs['img_metas'][0])
assert rewrite_outputs is not None


def get_pointpillars_nus():
from mmdet3d.models.detectors import MVXFasterRCNN

model = MVXFasterRCNN(**model_cfg.pointpillars_nus_model)
model.requires_grad_(False)
return model


@pytest.mark.parametrize('backend_type', [Backend.ONNXRUNTIME])
def test_pointpillars_nus(backend_type: Backend):
from mmdeploy.codebase.mmdet3d.deploy.voxel_detection import VoxelDetection
from mmdeploy.core import RewriterContext
check_backend(backend_type, True)
model = get_pointpillars_nus()
model.cpu().eval()
deploy_cfg = mmcv.Config(
dict(
backend_config=dict(type=backend_type.value),
onnx_config=dict(
input_shape=None,
opset_version=11,
input_names=['voxels', 'num_points', 'coors'],
output_names=['outputs']),
codebase_config=dict(
type=Codebase.MMDET3D.value, task=Task.VOXEL_DETECTION.value)))
voxeldetection = VoxelDetection(model_cfg, deploy_cfg, 'cpu')
inputs, data = voxeldetection.create_input(
'tests/test_codebase/test_mmdet3d/data/kitti/kitti_000008.bin')

with RewriterContext(
cfg=deploy_cfg,
backend=deploy_cfg.backend_config.type,
opset=deploy_cfg.onnx_config.opset_version):
outputs = model.forward(*data)
assert outputs is not None
3 changes: 2 additions & 1 deletion tools/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def parse_args():
nargs='+',
help='regression test yaml path.',
default=[
'mmcls', 'mmdet', 'mmseg', 'mmpose', 'mmocr', 'mmedit', 'mmrotate'
'mmcls', 'mmdet', 'mmseg', 'mmpose', 'mmocr', 'mmedit', 'mmrotate',
'mmdet3d'
])
parser.add_argument(
'-p',
Expand Down

0 comments on commit f957284

Please sign in to comment.