[Feature] Support MultiModalityDet3DInferencer #2342

Merged 17 commits on Mar 28, 2023
6 changes: 6 additions & 0 deletions configs/imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py
@@ -119,6 +119,7 @@
modality=input_modality,
test_mode=False,
metainfo=metainfo,
box_type_3d='LiDAR',
backend_args=backend_args)))
val_dataloader = dict(
batch_size=1,
@@ -135,6 +136,7 @@
modality=input_modality,
test_mode=True,
metainfo=metainfo,
box_type_3d='LiDAR',
backend_args=backend_args))
test_dataloader = val_dataloader
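
The explicit box_type_3d matters because the inference helpers resolve it through get_box_type; a minimal sketch of what that resolution yields (assuming the mmdet3d.structures API):

from mmdet3d.structures import get_box_type

box_type_3d, box_mode_3d = get_box_type('LiDAR')
# box_type_3d is the LiDARInstance3DBoxes class;
# box_mode_3d is Box3DMode.LIDAR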

@@ -168,3 +170,7 @@

# runtime
find_unused_parameters = True # only 1 of 4 FPN outputs is used

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
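
For context, the visualizer configured above is built through the registry at runtime; a minimal sketch (config path from this diff, the build steps are standard MMEngine usage):

from mmengine.config import Config
from mmengine.registry import init_default_scope

from mmdet3d.registry import VISUALIZERS

cfg = Config.fromfile('configs/imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py')
init_default_scope('mmdet3d')
# Builds a Det3DLocalVisualizer backed by LocalVisBackend
visualizer = VISUALIZERS.build(cfg.visualizer)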
1 change: 1 addition & 0 deletions configs/mvxnet/metafile.yml
@@ -18,6 +18,7 @@ Collections:

Models:
- Name: dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class
Alias: mvxnet_kitti-3class
In Collection: MVX-Net
Config: configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py
Metadata:
@@ -263,5 +263,9 @@
type='KittiMetric', ann_file='data/kitti/kitti_infos_val.pkl')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# You may need to download the model first if the network is unstable
load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth' # noqa
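
As a workaround for an unstable network, the checkpoint can be fetched once and load_from pointed at the local copy; a sketch (the destination path is an arbitrary choice):

import os

import torch

url = ('https://download.openmmlab.com/mmdetection3d/pretrain_models/'
       'mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_'
       'kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth')
os.makedirs('checkpoints', exist_ok=True)  # assumed local directory
torch.hub.download_url_to_file(url, 'checkpoints/mvx_coco_pretrain.pth')
# then in the config: load_from = 'checkpoints/mvx_coco_pretrain.pth'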
6 changes: 4 additions & 2 deletions mmdet3d/apis/__init__.py
@@ -4,11 +4,13 @@
inference_multi_modality_detector, inference_segmentor,
init_model)
from .inferencers import (Base3DInferencer, LidarDet3DInferencer,
LidarSeg3DInferencer, MonoDet3DInferencer)
LidarSeg3DInferencer, MonoDet3DInferencer,
MultiModalityDet3DInferencer)

__all__ = [
'inference_detector', 'init_model', 'inference_mono_3d_detector',
'convert_SyncBN', 'inference_multi_modality_detector',
'inference_segmentor', 'Base3DInferencer', 'MonoDet3DInferencer',
'LidarDet3DInferencer', 'LidarSeg3DInferencer'
'LidarDet3DInferencer', 'LidarSeg3DInferencer',
'MultiModalityDet3DInferencer'
]
18 changes: 12 additions & 6 deletions mmdet3d/apis/inference.py
@@ -176,8 +176,10 @@ def inference_multi_modality_detector(model: nn.Module,
pcds: Union[str, Sequence[str]],
imgs: Union[str, Sequence[str]],
ann_file: Union[str, Sequence[str]],
cam_type: str = 'CAM_FRONT'):
"""Inference point cloud with the multi-modality detector.
cam_type: str = 'CAM2'):
"""Inference point cloud with the multi-modality detector. Now we only
support multi-modality detector for KITTI dataset since the multi-view
image loading is not supported yet in this inference function.

Args:
model (nn.Module): The loaded detector.
@@ -187,7 +189,7 @@ def inference_multi_modality_detector(model: nn.Module,
Either image files or loaded images.
ann_file (str, Sequence[str]): Annotation files.
        cam_type (str): The camera whose image is used for inference.
For kitti dataset, it should be 'CAM_2',
For kitti dataset, it should be 'CAM2',
            and for the nuScenes dataset, it should be
            'CAM_FRONT'. Defaults to 'CAM2'.

@@ -216,7 +218,6 @@
get_box_type(cfg.test_dataloader.dataset.box_type_3d)

data_list = mmengine.load(ann_file)['data_list']
assert len(imgs) == len(data_list)

data = []
for index, pcd in enumerate(pcds):
@@ -228,13 +229,18 @@
if osp.basename(img_path) != osp.basename(img):
            raise ValueError(f'no info file is provided for image {img}.')

data_info['images'][cam_type]['img_path'] = img
cam2img = np.array(data_info['images'][cam_type]['cam2img'])

# TODO: check the name consistency of
# image file and point cloud file
# TODO: support multi-view image loading
data_ = dict(
lidar_points=dict(lidar_path=pcd),
img_path=img,
box_type_3d=box_type_3d,
box_mode_3d=box_mode_3d)
box_mode_3d=box_mode_3d,
cam2img=cam2img)

# LiDAR to image conversion for KITTI dataset
if box_mode_3d == Box3DMode.LIDAR:
@@ -295,7 +301,7 @@ def inference_mono_3d_detector(model: nn.Module,
box_type_3d, box_mode_3d = \
get_box_type(cfg.test_dataloader.dataset.box_type_3d)

data_list = mmengine.load(ann_file)
data_list = mmengine.load(ann_file)['data_list']
assert len(imgs) == len(data_list)

data = []
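For reference, a usage sketch of the updated function (checkpoint path and sample files are illustrative; the return convention of predictions plus processed data is assumed from demo/multi_modality_demo.py):

from mmdet3d.apis import inference_multi_modality_detector, init_model

config = 'configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py'
checkpoint = 'checkpoints/mvxnet_kitti-3class.pth'  # assumed local path
model = init_model(config, checkpoint, device='cuda:0')

# KITTI-style sample: point cloud, image and the info file holding cam2img
result, data = inference_multi_modality_detector(
    model,
    pcds='demo/data/kitti/000008.bin',  # illustrative sample files
    imgs='demo/data/kitti/000008.png',
    ann_file='demo/data/kitti/000008.pkl',
    cam_type='CAM2')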
3 changes: 2 additions & 1 deletion mmdet3d/apis/inferencers/__init__.py
@@ -3,8 +3,9 @@
from .lidar_det3d_inferencer import LidarDet3DInferencer
from .lidar_seg3d_inferencer import LidarSeg3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer
from .multi_modality_det3d_inferencer import MultiModalityDet3DInferencer

__all__ = [
'Base3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer',
'LidarSeg3DInferencer'
'LidarSeg3DInferencer', 'MultiModalityDet3DInferencer'
]
10 changes: 5 additions & 5 deletions mmdet3d/apis/inferencers/lidar_det3d_inferencer.py
@@ -93,19 +93,19 @@ def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline

load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadPointsFromFile')
if load_img_idx == -1:
load_point_idx = self._get_transform_idx(pipeline_cfg,
'LoadPointsFromFile')
if load_point_idx == -1:
raise ValueError(
'LoadPointsFromFile is not found in the test pipeline')

load_cfg = pipeline_cfg[load_img_idx]
load_cfg = pipeline_cfg[load_point_idx]
self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[
'load_dim']
self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
load_cfg['use_dim'], int) else load_cfg['use_dim']

pipeline_cfg[load_img_idx]['type'] = 'LidarDet3DInferencerLoader'
pipeline_cfg[load_point_idx]['type'] = 'LidarDet3DInferencerLoader'
return Compose(pipeline_cfg)

def visualize(self,
226 changes: 226 additions & 0 deletions mmdet3d/apis/inferencers/multi_modality_det3d_inferencer.py
@@ -0,0 +1,226 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import Dict, List, Optional, Sequence, Union

import mmcv
import mmengine
import numpy as np
from mmengine.dataset import Compose
from mmengine.infer.infer import ModelType
from mmengine.structures import InstanceData

from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_3d_inferencer import Base3DInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
InputsType = Union[InputType, Sequence[InputType]]
PredType = Union[InstanceData, InstanceList]
ImgType = Union[np.ndarray, Sequence[np.ndarray]]
ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]


@INFERENCERS.register_module(name='det3d-multi_modality')
@INFERENCERS.register_module()
class MultiModalityDet3DInferencer(Base3DInferencer):
"""The inferencer of multi-modality detection.

Args:
model (str, optional): Path to the config file or the model name
defined in metafile. For example, it could be
"pointpillars_kitti-3class" or
"configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py". # noqa: E501
            If model is not specified, the user must provide the
            `weights` saved by MMEngine, which contains the config string.
Defaults to None.
weights (str, optional): Path to the checkpoint. If it is not specified
and model is a model name of metafile, the weights will be loaded
from metafile. Defaults to None.
device (str, optional): Device to run inference. If None, the available
device will be automatically used. Defaults to None.
scope (str): The scope of registry. Defaults to 'mmdet3d'.
palette (str): The palette of visualization. Defaults to 'none'.
"""

preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
}

def __init__(self,
model: Union[ModelType, str, None] = None,
weights: Optional[str] = None,
device: Optional[str] = None,
scope: str = 'mmdet3d',
palette: str = 'none') -> None:
        # A global counter tracking the number of frames processed, used for
        # naming the output results
self.num_visualized_frames = 0
super(MultiModalityDet3DInferencer, self).__init__(
model=model,
weights=weights,
device=device,
scope=scope,
palette=palette)

def _inputs_to_list(self, inputs: Union[dict, list]) -> list:
"""Preprocess the inputs to a list.

Preprocess inputs to a list according to its type:

        - list or tuple: return inputs
        - dict: for each value with key 'points' or 'img':
            - Directory path: return all files in the directory
            - other cases: return a list containing the string. The string
              could be a path to a file, a URL or another type of string,
              depending on the task

Args:
inputs (Union[dict, list]): Inputs for the inferencer.

Returns:
list: List of input for the :meth:`preprocess`.
"""
return super()._inputs_to_list(inputs, modality_key=['points', 'img'])

def _init_pipeline(self, cfg: ConfigType) -> Compose:
"""Initialize the test pipeline."""
pipeline_cfg = cfg.test_dataloader.dataset.pipeline

load_point_idx = self._get_transform_idx(pipeline_cfg,
'LoadPointsFromFile')
# Now, we only support ``LoadImageFromFile`` as the image loader in the
        # original pipeline. `LoadMultiViewImageFromFiles` is not supported
# yet.
load_img_idx = self._get_transform_idx(pipeline_cfg,
'LoadImageFromFile')

if load_point_idx == -1 or load_img_idx == -1:
            raise ValueError(
                'Both LoadPointsFromFile and LoadImageFromFile must '
                'be specified in the test pipeline, but got transform '
                f'indices {load_point_idx} and {load_img_idx} '
                '(-1 means the transform is missing)')

load_cfg = pipeline_cfg[load_point_idx]
self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[
'load_dim']
self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
load_cfg['use_dim'], int) else load_cfg['use_dim']

load_point_args = pipeline_cfg[load_point_idx]
load_point_args.pop('type')
load_img_args = pipeline_cfg[load_img_idx]
load_img_args.pop('type')

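        # Merge the two loaders into a single transform, placed at the
        # position of the earlier one so the rest of the pipeline keeps
        # its original order.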
load_idx = min(load_point_idx, load_img_idx)
pipeline_cfg.pop(max(load_point_idx, load_img_idx))

pipeline_cfg[load_idx] = dict(
type='MultiModalityDet3DInferencerLoader',
load_point_args=load_point_args,
load_img_args=load_img_args)

return Compose(pipeline_cfg)

def visualize(self,
inputs: InputsType,
preds: PredType,
return_vis: bool = False,
show: bool = False,
wait_time: int = 0,
draw_pred: bool = True,
pred_score_thr: float = 0.3,
img_out_dir: str = '') -> Union[List[np.ndarray], None]:
"""Visualize predictions.

Args:
inputs (InputsType): Inputs for the inferencer.
preds (PredType): Predictions of the model.
return_vis (bool): Whether to return the visualization result.
Defaults to False.
show (bool): Whether to display the image in a popup window.
Defaults to False.
wait_time (float): The interval of show (s). Defaults to 0.
draw_pred (bool): Whether to draw predicted bounding boxes.
Defaults to True.
pred_score_thr (float): Minimum score of bboxes to draw.
Defaults to 0.3.
img_out_dir (str): Output directory of visualization results.
If left as empty, no file will be saved. Defaults to ''.

Returns:
List[np.ndarray] or None: Returns visualization results only if
applicable.
"""
        if not show and img_out_dir == '' and not return_vis:
            return None

        if self.visualizer is None:
            raise ValueError('Visualization needs the "visualizer" term '
                             'defined in the config, but got None.')

results = []

for single_input, pred in zip(inputs, preds):
points_input = single_input['points']
if isinstance(points_input, str):
pts_bytes = mmengine.fileio.get(points_input)
points = np.frombuffer(pts_bytes, dtype=np.float32)
points = points.reshape(-1, self.load_dim)
points = points[:, self.use_dim]
pc_name = osp.basename(points_input).split('.bin')[0]
pc_name = f'{pc_name}.png'
elif isinstance(points_input, np.ndarray):
points = points_input.copy()
pc_num = str(self.num_visualized_frames).zfill(8)
pc_name = f'pc_{pc_num}.png'
else:
raise ValueError('Unsupported input type: '
f'{type(points_input)}')

o3d_save_path = osp.join(img_out_dir, pc_name) \
if img_out_dir != '' else None

img_input = single_input['img']
            if isinstance(img_input, str):
img_bytes = mmengine.fileio.get(img_input)
img = mmcv.imfrombytes(img_bytes)
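                # mmcv.imfrombytes decodes to BGR; reverse the channel
                # order to get RGB for visualization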
img = img[:, :, ::-1]
img_name = osp.basename(img_input)
elif isinstance(img_input, np.ndarray):
img = img_input.copy()
img_num = str(self.num_visualized_frames).zfill(8)
img_name = f'{img_num}.jpg'
else:
raise ValueError('Unsupported input type: '
f'{type(img_input)}')

out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \
else None

data_input = dict(points=points, img=img)
self.visualizer.add_datasample(
pc_name,
data_input,
pred,
show=show,
wait_time=wait_time,
draw_gt=False,
draw_pred=draw_pred,
pred_score_thr=pred_score_thr,
o3d_save_path=o3d_save_path,
out_file=out_file,
vis_task='multi-modality_det',
)
results.append(points)
self.num_visualized_frames += 1

return results
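
A usage sketch for the new inferencer (the 'mvxnet_kitti-3class' alias comes from the metafile change above; input paths are illustrative and the exact input keys expected by MultiModalityDet3DInferencerLoader are an assumption):

from mmdet3d.apis import MultiModalityDet3DInferencer

inferencer = MultiModalityDet3DInferencer(model='mvxnet_kitti-3class')

# 'points' and 'img' are the modality keys handled by _inputs_to_list;
# calibration info (e.g. a KITTI info .pkl) may also be needed by the
# loader, depending on the dataset.
inputs = dict(
    points='demo/data/kitti/000008.bin',  # illustrative sample files
    img='demo/data/kitti/000008.png')
results = inferencer(inputs, pred_score_thr=0.3, img_out_dir='outputs')
# results is expected to hold 'predictions' (and optionally
# 'visualization'), following the BaseInferencer convention.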
6 changes: 4 additions & 2 deletions mmdet3d/datasets/transforms/__init__.py
@@ -4,7 +4,8 @@
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadMultiViewImageFromFiles, LoadPointsFromDict,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
MonoDet3DInferencerLoader, NormalizePointsColor,
MonoDet3DInferencerLoader,
MultiModalityDet3DInferencerLoader, NormalizePointsColor,
PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
# yapf: disable
@@ -30,5 +31,6 @@
'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize',
'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D',
'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader',
'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix'
'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix',
'MultiModalityDet3DInferencerLoader'
]