diff --git a/configs/imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py b/configs/imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py index 49b6eeb732..df1e9d69ca 100644 --- a/configs/imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py +++ b/configs/imvoxelnet/imvoxelnet_8xb4_kitti-3d-car.py @@ -119,6 +119,7 @@ modality=input_modality, test_mode=False, metainfo=metainfo, + box_type_3d='LiDAR', backend_args=backend_args))) val_dataloader = dict( batch_size=1, @@ -135,6 +136,7 @@ modality=input_modality, test_mode=True, metainfo=metainfo, + box_type_3d='LiDAR', backend_args=backend_args)) test_dataloader = val_dataloader @@ -168,3 +170,7 @@ # runtime find_unused_parameters = True # only 1 of 4 FPN outputs is used + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') diff --git a/configs/mvxnet/metafile.yml b/configs/mvxnet/metafile.yml index c4d486d3cc..6eb341ab23 100644 --- a/configs/mvxnet/metafile.yml +++ b/configs/mvxnet/metafile.yml @@ -18,6 +18,7 @@ Collections: Models: - Name: dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class + Alias: mvxnet_kitti-3class In Collection: MVX-Net Config: configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py Metadata: diff --git a/configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py b/configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py index ea2be54dd6..feceb17c75 100644 --- a/configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py +++ b/configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py @@ -263,5 +263,9 @@ type='KittiMetric', ann_file='data/kitti/kitti_infos_val.pkl') test_evaluator = val_evaluator +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + # You may need to download the model first is the network is unstable load_from = 'https://download.openmmlab.com/mmdetection3d/pretrain_models/mvx_faster_rcnn_detectron2-caffe_20e_coco-pretrain_gt-sample_kitti-3-class_moderate-79.3_20200207-a4a6a3c7.pth' # noqa diff --git a/mmdet3d/apis/__init__.py b/mmdet3d/apis/__init__.py index eb4724dc65..57f732f166 100644 --- a/mmdet3d/apis/__init__.py +++ b/mmdet3d/apis/__init__.py @@ -4,11 +4,13 @@ inference_multi_modality_detector, inference_segmentor, init_model) from .inferencers import (Base3DInferencer, LidarDet3DInferencer, - LidarSeg3DInferencer, MonoDet3DInferencer) + LidarSeg3DInferencer, MonoDet3DInferencer, + MultiModalityDet3DInferencer) __all__ = [ 'inference_detector', 'init_model', 'inference_mono_3d_detector', 'convert_SyncBN', 'inference_multi_modality_detector', 'inference_segmentor', 'Base3DInferencer', 'MonoDet3DInferencer', - 'LidarDet3DInferencer', 'LidarSeg3DInferencer' + 'LidarDet3DInferencer', 'LidarSeg3DInferencer', + 'MultiModalityDet3DInferencer' ] diff --git a/mmdet3d/apis/inference.py b/mmdet3d/apis/inference.py index 0e4dc88859..98b4373d07 100644 --- a/mmdet3d/apis/inference.py +++ b/mmdet3d/apis/inference.py @@ -176,8 +176,10 @@ def inference_multi_modality_detector(model: nn.Module, pcds: Union[str, Sequence[str]], imgs: Union[str, Sequence[str]], ann_file: Union[str, Sequence[str]], - cam_type: str = 'CAM_FRONT'): - """Inference point cloud with the multi-modality detector. + cam_type: str = 'CAM2'): + """Inference point cloud with the multi-modality detector. 
Now we only + support multi-modality detector for KITTI dataset since the multi-view + image loading is not supported yet in this inference function. Args: model (nn.Module): The loaded detector. @@ -187,7 +189,7 @@ def inference_multi_modality_detector(model: nn.Module, Either image files or loaded images. ann_file (str, Sequence[str]): Annotation files. cam_type (str): Image of Camera chose to infer. - For kitti dataset, it should be 'CAM_2', + For kitti dataset, it should be 'CAM2', and for nuscenes dataset, it should be 'CAM_FRONT'. Defaults to 'CAM_FRONT'. @@ -216,7 +218,6 @@ def inference_multi_modality_detector(model: nn.Module, get_box_type(cfg.test_dataloader.dataset.box_type_3d) data_list = mmengine.load(ann_file)['data_list'] - assert len(imgs) == len(data_list) data = [] for index, pcd in enumerate(pcds): @@ -228,13 +229,18 @@ def inference_multi_modality_detector(model: nn.Module, if osp.basename(img_path) != osp.basename(img): raise ValueError(f'the info file of {img_path} is not provided.') + data_info['images'][cam_type]['img_path'] = img + cam2img = np.array(data_info['images'][cam_type]['cam2img']) + # TODO: check the name consistency of # image file and point cloud file + # TODO: support multi-view image loading data_ = dict( lidar_points=dict(lidar_path=pcd), img_path=img, box_type_3d=box_type_3d, - box_mode_3d=box_mode_3d) + box_mode_3d=box_mode_3d, + cam2img=cam2img) # LiDAR to image conversion for KITTI dataset if box_mode_3d == Box3DMode.LIDAR: @@ -295,7 +301,7 @@ def inference_mono_3d_detector(model: nn.Module, box_type_3d, box_mode_3d = \ get_box_type(cfg.test_dataloader.dataset.box_type_3d) - data_list = mmengine.load(ann_file) + data_list = mmengine.load(ann_file)['data_list'] assert len(imgs) == len(data_list) data = [] diff --git a/mmdet3d/apis/inferencers/__init__.py b/mmdet3d/apis/inferencers/__init__.py index 5332a8370e..0da7b52a3e 100644 --- a/mmdet3d/apis/inferencers/__init__.py +++ b/mmdet3d/apis/inferencers/__init__.py @@ -3,8 +3,9 @@ from .lidar_det3d_inferencer import LidarDet3DInferencer from .lidar_seg3d_inferencer import LidarSeg3DInferencer from .mono_det3d_inferencer import MonoDet3DInferencer +from .multi_modality_det3d_inferencer import MultiModalityDet3DInferencer __all__ = [ 'Base3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer', - 'LidarSeg3DInferencer' + 'LidarSeg3DInferencer', 'MultiModalityDet3DInferencer' ] diff --git a/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py b/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py index ae04b6b03c..a3fdc479b7 100644 --- a/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py +++ b/mmdet3d/apis/inferencers/lidar_det3d_inferencer.py @@ -93,19 +93,19 @@ def _init_pipeline(self, cfg: ConfigType) -> Compose: """Initialize the test pipeline.""" pipeline_cfg = cfg.test_dataloader.dataset.pipeline - load_img_idx = self._get_transform_idx(pipeline_cfg, - 'LoadPointsFromFile') - if load_img_idx == -1: + load_point_idx = self._get_transform_idx(pipeline_cfg, + 'LoadPointsFromFile') + if load_point_idx == -1: raise ValueError( 'LoadPointsFromFile is not found in the test pipeline') - load_cfg = pipeline_cfg[load_img_idx] + load_cfg = pipeline_cfg[load_point_idx] self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[ 'load_dim'] self.use_dim = list(range(load_cfg['use_dim'])) if isinstance( load_cfg['use_dim'], int) else load_cfg['use_dim'] - pipeline_cfg[load_img_idx]['type'] = 'LidarDet3DInferencerLoader' + pipeline_cfg[load_point_idx]['type'] = 'LidarDet3DInferencerLoader' return 
Compose(pipeline_cfg) def visualize(self, diff --git a/mmdet3d/apis/inferencers/multi_modality_det3d_inferencer.py b/mmdet3d/apis/inferencers/multi_modality_det3d_inferencer.py new file mode 100644 index 0000000000..ab02e064b8 --- /dev/null +++ b/mmdet3d/apis/inferencers/multi_modality_det3d_inferencer.py @@ -0,0 +1,233 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +import warnings +from typing import Dict, List, Optional, Sequence, Union + +import mmcv +import mmengine +import numpy as np +from mmengine.dataset import Compose +from mmengine.infer.infer import ModelType +from mmengine.structures import InstanceData + +from mmdet3d.registry import INFERENCERS +from mmdet3d.utils import ConfigType +from .base_3d_inferencer import Base3DInferencer + +InstanceList = List[InstanceData] +InputType = Union[str, np.ndarray] +InputsType = Union[InputType, Sequence[InputType]] +PredType = Union[InstanceData, InstanceList] +ImgType = Union[np.ndarray, Sequence[np.ndarray]] +ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]] + + +@INFERENCERS.register_module(name='det3d-multi_modality') +@INFERENCERS.register_module() +class MultiModalityDet3DInferencer(Base3DInferencer): + """The inferencer of multi-modality detection. + + Args: + model (str, optional): Path to the config file or the model name + defined in metafile. For example, it could be + "pointpillars_kitti-3class" or + "configs/pointpillars/pointpillars_hv_secfpn_8xb6-160e_kitti-3d-3class.py". # noqa: E501 + If model is not specified, user must provide the + `weights` saved by MMEngine which contains the config string. + Defaults to None. + weights (str, optional): Path to the checkpoint. If it is not specified + and model is a model name of metafile, the weights will be loaded + from metafile. Defaults to None. + device (str, optional): Device to run inference. If None, the available + device will be automatically used. Defaults to None. + scope (str): The scope of registry. Defaults to 'mmdet3d'. + palette (str): The palette of visualization. Defaults to 'none'. + """ + + preprocess_kwargs: set = set() + forward_kwargs: set = set() + visualize_kwargs: set = { + 'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr', + 'img_out_dir' + } + postprocess_kwargs: set = { + 'print_result', 'pred_out_file', 'return_datasample' + } + + def __init__(self, + model: Union[ModelType, str, None] = None, + weights: Optional[str] = None, + device: Optional[str] = None, + scope: str = 'mmdet3d', + palette: str = 'none') -> None: + # A global counter tracking the number of frames processed, for + # naming of the output results + self.num_visualized_frames = 0 + super(MultiModalityDet3DInferencer, self).__init__( + model=model, + weights=weights, + device=device, + scope=scope, + palette=palette) + + def _inputs_to_list(self, inputs: Union[dict, list]) -> list: + """Preprocess the inputs to a list. + + Preprocess inputs to a list according to its type: + + - list or tuple: return inputs + - dict: the value with key 'points' is + - Directory path: return all files in the directory + - other cases: return a list containing the string. The string + could be a path to file, a url or other types of string according + to the task. + + Args: + inputs (Union[dict, list]): Inputs for the inferencer. + + Returns: + list: List of input for the :meth:`preprocess`. 
+        """
+        return super()._inputs_to_list(inputs, modality_key=['points', 'img'])
+
+    def _init_pipeline(self, cfg: ConfigType) -> Compose:
+        """Initialize the test pipeline."""
+        pipeline_cfg = cfg.test_dataloader.dataset.pipeline
+
+        load_point_idx = self._get_transform_idx(pipeline_cfg,
+                                                 'LoadPointsFromFile')
+        load_mv_img_idx = self._get_transform_idx(
+            pipeline_cfg, 'LoadMultiViewImageFromFiles')
+        if load_mv_img_idx != -1:
+            warnings.warn(
+                'LoadMultiViewImageFromFiles is not supported yet in the '
+                'multi-modality inferencer. Please remove it')
+        # Now, we only support ``LoadImageFromFile`` as the image loader in the
+        # original pipeline. `LoadMultiViewImageFromFiles` is not supported
+        # yet.
+        load_img_idx = self._get_transform_idx(pipeline_cfg,
+                                               'LoadImageFromFile')
+
+        if load_point_idx == -1 or load_img_idx == -1:
+            raise ValueError(
+                'Both LoadPointsFromFile and LoadImageFromFile must '
+                'be specified in the pipeline, but LoadPointsFromFile is '
+                f'{load_point_idx != -1} and LoadImageFromFile is '
+                f'{load_img_idx != -1}')
+
+        load_cfg = pipeline_cfg[load_point_idx]
+        self.coord_type, self.load_dim = load_cfg['coord_type'], load_cfg[
+            'load_dim']
+        self.use_dim = list(range(load_cfg['use_dim'])) if isinstance(
+            load_cfg['use_dim'], int) else load_cfg['use_dim']
+
+        load_point_args = pipeline_cfg[load_point_idx]
+        load_point_args.pop('type')
+        load_img_args = pipeline_cfg[load_img_idx]
+        load_img_args.pop('type')
+
+        load_idx = min(load_point_idx, load_img_idx)
+        pipeline_cfg.pop(max(load_point_idx, load_img_idx))
+
+        pipeline_cfg[load_idx] = dict(
+            type='MultiModalityDet3DInferencerLoader',
+            load_point_args=load_point_args,
+            load_img_args=load_img_args)
+
+        return Compose(pipeline_cfg)
+
+    def visualize(self,
+                  inputs: InputsType,
+                  preds: PredType,
+                  return_vis: bool = False,
+                  show: bool = False,
+                  wait_time: int = 0,
+                  draw_pred: bool = True,
+                  pred_score_thr: float = 0.3,
+                  img_out_dir: str = '') -> Union[List[np.ndarray], None]:
+        """Visualize predictions.
+
+        Args:
+            inputs (InputsType): Inputs for the inferencer.
+            preds (PredType): Predictions of the model.
+            return_vis (bool): Whether to return the visualization result.
+                Defaults to False.
+            show (bool): Whether to display the image in a popup window.
+                Defaults to False.
+            wait_time (float): The interval of show (s). Defaults to 0.
+            draw_pred (bool): Whether to draw predicted bounding boxes.
+                Defaults to True.
+            pred_score_thr (float): Minimum score of bboxes to draw.
+                Defaults to 0.3.
+            img_out_dir (str): Output directory of visualization results.
+                If left as empty, no file will be saved. Defaults to ''.
+
+        Returns:
+            List[np.ndarray] or None: Returns visualization results only if
+                applicable.
+ """ + if self.visualizer is None or (not show and img_out_dir == '' + and not return_vis): + return None + + if getattr(self, 'visualizer') is None: + raise ValueError('Visualization needs the "visualizer" term' + 'defined in the config, but got None.') + + results = [] + + for single_input, pred in zip(inputs, preds): + points_input = single_input['points'] + if isinstance(points_input, str): + pts_bytes = mmengine.fileio.get(points_input) + points = np.frombuffer(pts_bytes, dtype=np.float32) + points = points.reshape(-1, self.load_dim) + points = points[:, self.use_dim] + pc_name = osp.basename(points_input).split('.bin')[0] + pc_name = f'{pc_name}.png' + elif isinstance(points_input, np.ndarray): + points = points_input.copy() + pc_num = str(self.num_visualized_frames).zfill(8) + pc_name = f'pc_{pc_num}.png' + else: + raise ValueError('Unsupported input type: ' + f'{type(points_input)}') + + o3d_save_path = osp.join(img_out_dir, pc_name) \ + if img_out_dir != '' else None + + img_input = single_input['img'] + if isinstance(single_input['img'], str): + img_bytes = mmengine.fileio.get(img_input) + img = mmcv.imfrombytes(img_bytes) + img = img[:, :, ::-1] + img_name = osp.basename(img_input) + elif isinstance(img_input, np.ndarray): + img = img_input.copy() + img_num = str(self.num_visualized_frames).zfill(8) + img_name = f'{img_num}.jpg' + else: + raise ValueError('Unsupported input type: ' + f'{type(img_input)}') + + out_file = osp.join(img_out_dir, img_name) if img_out_dir != '' \ + else None + + data_input = dict(points=points, img=img) + self.visualizer.add_datasample( + pc_name, + data_input, + pred, + show=show, + wait_time=wait_time, + draw_gt=False, + draw_pred=draw_pred, + pred_score_thr=pred_score_thr, + o3d_save_path=o3d_save_path, + out_file=out_file, + vis_task='multi-modality_det', + ) + results.append(points) + self.num_visualized_frames += 1 + + return results diff --git a/mmdet3d/datasets/transforms/__init__.py b/mmdet3d/datasets/transforms/__init__.py index 09fbed659b..4c0587f80e 100644 --- a/mmdet3d/datasets/transforms/__init__.py +++ b/mmdet3d/datasets/transforms/__init__.py @@ -4,7 +4,8 @@ from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D, LoadMultiViewImageFromFiles, LoadPointsFromDict, LoadPointsFromFile, LoadPointsFromMultiSweeps, - MonoDet3DInferencerLoader, NormalizePointsColor, + MonoDet3DInferencerLoader, + MultiModalityDet3DInferencerLoader, NormalizePointsColor, PointSegClassMapping) from .test_time_aug import MultiScaleFlipAug3D # yapf: disable @@ -30,5 +31,6 @@ 'RandomDropPointsColor', 'RandomJitterPoints', 'AffineResize', 'RandomShiftScale', 'LoadPointsFromDict', 'Resize3D', 'RandomResize3D', 'MultiViewWrapper', 'PhotoMetricDistortion3D', 'MonoDet3DInferencerLoader', - 'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix' + 'LidarDet3DInferencerLoader', 'PolarMix', 'LaserMix', + 'MultiModalityDet3DInferencerLoader' ] diff --git a/mmdet3d/datasets/transforms/loading.py b/mmdet3d/datasets/transforms/loading.py index 129179538a..a240a1c032 100644 --- a/mmdet3d/datasets/transforms/loading.py +++ b/mmdet3d/datasets/transforms/loading.py @@ -1201,3 +1201,119 @@ def transform(self, single_input: dict) -> dict: if 'img' in inputs: return self.from_ndarray(inputs) return self.from_file(inputs) + + +@TRANSFORMS.register_module() +class MultiModalityDet3DInferencerLoader(BaseTransform): + """Load point cloud and image in the Inferencer's pipeline. 
+
+    Added keys:
+      - points
+      - img
+      - cam2img
+      - lidar2cam
+      - lidar2img
+      - timestamp
+      - axis_align_matrix
+      - box_type_3d
+      - box_mode_3d
+    """
+
+    def __init__(self, load_point_args: dict, load_img_args: dict) -> None:
+        super().__init__()
+        self.points_from_file = TRANSFORMS.build(
+            dict(type='LoadPointsFromFile', **load_point_args))
+        self.points_from_ndarray = TRANSFORMS.build(
+            dict(type='LoadPointsFromDict', **load_point_args))
+        coord_type = load_point_args['coord_type']
+        self.box_type_3d, self.box_mode_3d = get_box_type(coord_type)
+
+        self.imgs_from_file = TRANSFORMS.build(
+            dict(type='LoadImageFromFile', **load_img_args))
+        self.imgs_from_ndarray = TRANSFORMS.build(
+            dict(type='LoadImageFromNDArray', **load_img_args))
+
+    def transform(self, single_input: dict) -> dict:
+        """Transform function to add image meta information.
+
+        Args:
+            single_input (dict): Single input.
+
+        Returns:
+            dict: The dict contains loaded image, point cloud and meta
+                information.
+        """
+        assert 'points' in single_input and 'img' in single_input and \
+            'calib' in single_input, \
+            "key 'points', 'img' and 'calib' must be " \
+            f'in input dict, but got {single_input}'
+        if isinstance(single_input['points'], str):
+            inputs = dict(
+                lidar_points=dict(lidar_path=single_input['points']),
+                timestamp=1,
+                # for ScanNet demo we need axis_align_matrix
+                axis_align_matrix=np.eye(4),
+                box_type_3d=self.box_type_3d,
+                box_mode_3d=self.box_mode_3d)
+        elif isinstance(single_input['points'], np.ndarray):
+            inputs = dict(
+                points=single_input['points'],
+                timestamp=1,
+                # for ScanNet demo we need axis_align_matrix
+                axis_align_matrix=np.eye(4),
+                box_type_3d=self.box_type_3d,
+                box_mode_3d=self.box_mode_3d)
+        else:
+            raise ValueError('Unsupported input points type: '
+                             f"{type(single_input['points'])}")
+
+        if 'points' in inputs:
+            points_inputs = self.points_from_ndarray(inputs)
+        else:
+            points_inputs = self.points_from_file(inputs)
+
+        multi_modality_inputs = points_inputs
+
+        box_type_3d, box_mode_3d = get_box_type('lidar')
+        if isinstance(single_input['calib'], str):
+            calib = mmengine.load(single_input['calib'])
+        elif isinstance(single_input['calib'], dict):
+            calib = single_input['calib']
+        else:
+            raise ValueError('Unsupported input calib type: '
+                             f"{type(single_input['calib'])}")
+
+        cam2img = np.asarray(calib['cam2img'], dtype=np.float32)
+        lidar2cam = np.asarray(calib['lidar2cam'], dtype=np.float32)
+        if 'lidar2img' in calib:
+            lidar2img = np.asarray(calib['lidar2img'], dtype=np.float32)
+        else:
+            lidar2img = cam2img @ lidar2cam
+
+        if isinstance(single_input['img'], str):
+            inputs = dict(
+                img_path=single_input['img'],
+                cam2img=cam2img,
+                lidar2img=lidar2img,
+                lidar2cam=lidar2cam,
+                box_mode_3d=box_mode_3d,
+                box_type_3d=box_type_3d)
+        elif isinstance(single_input['img'], np.ndarray):
+            inputs = dict(
+                img=single_input['img'],
+                cam2img=cam2img,
+                lidar2img=lidar2img,
+                lidar2cam=lidar2cam,
+                box_type_3d=box_type_3d,
+                box_mode_3d=box_mode_3d)
+        else:
+            raise ValueError('Unsupported input image type: '
+                             f"{type(single_input['img'])}")
+
+        if isinstance(single_input['img'], np.ndarray):
+            imgs_inputs = self.imgs_from_ndarray(inputs)
+        else:
+            imgs_inputs = self.imgs_from_file(inputs)
+
+        multi_modality_inputs.update(imgs_inputs)
+
+        return multi_modality_inputs
diff --git a/mmdet3d/visualization/local_visualizer.py b/mmdet3d/visualization/local_visualizer.py
index 5c36d37d91..aaa3ce9e39 100644
--- a/mmdet3d/visualization/local_visualizer.py
+++ b/mmdet3d/visualization/local_visualizer.py
@@
-822,6 +822,9 @@ def add_datasample(self, wait_time=wait_time) if out_file is not None: + # check the suffix of the name of image file + if not (out_file.endswith('.png') or out_file.endswith('.jpg')): + out_file = f'{out_file}.png' if drawn_img_3d is not None: mmcv.imwrite(drawn_img_3d[..., ::-1], out_file) if drawn_img is not None: diff --git a/tests/data/kitti/training/calib/000000.pkl b/tests/data/kitti/training/calib/000000.pkl new file mode 100644 index 0000000000..0937a57a33 Binary files /dev/null and b/tests/data/kitti/training/calib/000000.pkl differ diff --git a/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py b/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py index 5bcb2148b5..7d76b627ac 100644 --- a/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py +++ b/tests/test_apis/test_inferencers/test_lidar_det3d_inferencer.py @@ -43,7 +43,7 @@ def assert_predictions_equal(self, preds1, preds2): def test_call(self): if not torch.cuda.is_available(): return - # single img + # single point cloud inputs = dict(points='tests/data/kitti/training/velodyne/000000.bin') res_path = self.inferencer(inputs, return_vis=True) # ndarray @@ -58,7 +58,7 @@ def test_call(self): self.assertIn('visualization', res_path) self.assertIn('visualization', res_ndarray) - # multiple images + # multiple point clouds inputs = [ dict(points='tests/data/kitti/training/velodyne/000000.bin'), dict(points='tests/data/kitti/training/velodyne/000000.bin') diff --git a/tests/test_apis/test_inferencers/test_multi_modality_det3d_inferencer.py b/tests/test_apis/test_inferencers/test_multi_modality_det3d_inferencer.py new file mode 100644 index 0000000000..e0be781ca2 --- /dev/null +++ b/tests/test_apis/test_inferencers/test_multi_modality_det3d_inferencer.py @@ -0,0 +1,120 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+import os.path as osp +import tempfile +from unittest import TestCase + +import mmcv +import mmengine +import numpy as np +import torch +from mmengine.utils import is_list_of + +from mmdet3d.apis import MultiModalityDet3DInferencer +from mmdet3d.structures import Det3DDataSample + + +class TestMultiModalityDet3DInferencer(TestCase): + + def setUp(self): + # init from alias + self.inferencer = MultiModalityDet3DInferencer('mvxnet_kitti-3class') + + def test_init(self): + # init from metafile + MultiModalityDet3DInferencer('mvxnet_kitti-3class') + # init from cfg + MultiModalityDet3DInferencer( + 'configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py', # noqa + weights= # noqa + 'https://download.openmmlab.com/mmdetection3d/v1.0.0_models/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class_20210831_060805-83442923.pth' # noqa + ) + + def assert_predictions_equal(self, preds1, preds2): + for pred1, pred2 in zip(preds1, preds2): + if 'bboxes_3d' in pred1: + self.assertTrue( + np.allclose(pred1['bboxes_3d'], pred2['bboxes_3d'], 0.1)) + if 'scores_3d' in pred1: + self.assertTrue( + np.allclose(pred1['scores_3d'], pred2['scores_3d'], 0.1)) + if 'labels_3d' in pred1: + self.assertTrue( + np.allclose(pred1['labels_3d'], pred2['labels_3d'])) + + def test_call(self): + if not torch.cuda.is_available(): + return + calib_path = 'tests/data/kitti/training/calib/000000.pkl' + points_path = 'tests/data/kitti/training/velodyne/000000.bin' + img_path = 'tests/data/kitti/training/image_2/000000.png' + # single img & point cloud + inputs = dict(points=points_path, img=img_path, calib=calib_path) + res_path = self.inferencer(inputs, return_vis=True) + + # ndarray + pts_bytes = mmengine.fileio.get(inputs['points']) + points = np.frombuffer(pts_bytes, dtype=np.float32) + points = points.reshape(-1, 4) + points = points[:, :4] + img = mmcv.imread(inputs['img']) + inputs = dict(points=points, img=img, calib=calib_path) + res_ndarray = self.inferencer(inputs, return_vis=True) + self.assert_predictions_equal(res_path['predictions'], + res_ndarray['predictions']) + self.assertIn('visualization', res_path) + self.assertIn('visualization', res_ndarray) + + # multiple imgs & point clouds + inputs = [ + dict(points=points_path, img=img_path, calib=calib_path), + dict(points=points_path, img=img_path, calib=calib_path) + ] + res_path = self.inferencer(inputs, return_vis=True) + # list of ndarray + all_inputs = [] + for p in inputs: + pts_bytes = mmengine.fileio.get(p['points']) + points = np.frombuffer(pts_bytes, dtype=np.float32) + points = points.reshape(-1, 4) + img = mmcv.imread(p['img']) + all_inputs.append(dict(points=points, img=img, calib=p['calib'])) + + res_ndarray = self.inferencer(all_inputs, return_vis=True) + self.assert_predictions_equal(res_path['predictions'], + res_ndarray['predictions']) + self.assertIn('visualization', res_path) + self.assertIn('visualization', res_ndarray) + + def test_visualize(self): + if not torch.cuda.is_available(): + return + inputs = dict( + points='tests/data/kitti/training/velodyne/000000.bin', + img='tests/data/kitti/training/image_2/000000.png', + calib='tests/data/kitti/training/calib/000000.pkl'), + # img_out_dir + with tempfile.TemporaryDirectory() as tmp_dir: + self.inferencer(inputs, img_out_dir=tmp_dir) + # TODO: For results of LiDAR-based detection, the saved image only + # exists when show=True. 
+ # self.assertTrue(osp.exists(osp.join(tmp_dir, '000000.png'))) + + def test_postprocess(self): + if not torch.cuda.is_available(): + return + # return_datasample + inputs = dict( + points='tests/data/kitti/training/velodyne/000000.bin', + img='tests/data/kitti/training/image_2/000000.png', + calib='tests/data/kitti/training/calib/000000.pkl') + res = self.inferencer(inputs, return_datasamples=True) + self.assertTrue(is_list_of(res['predictions'], Det3DDataSample)) + + # pred_out_file + with tempfile.TemporaryDirectory() as tmp_dir: + pred_out_file = osp.join(tmp_dir, 'tmp.json') + res = self.inferencer( + inputs, print_result=True, pred_out_file=pred_out_file) + dumped_res = mmengine.load(pred_out_file) + self.assert_predictions_equal(res['predictions'], + dumped_res['predictions'])
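
For orientation, here is a minimal usage sketch of the new inferencer, mirroring the single-sample path exercised in test_multi_modality_det3d_inferencer.py. It is a sketch rather than an official demo: weights are resolved from the `mvxnet_kitti-3class` metafile alias added above, and the input paths are the KITTI assets shipped under tests/data.

from mmdet3d.apis import MultiModalityDet3DInferencer

# Config and weights are resolved through the metafile alias added in this change.
inferencer = MultiModalityDet3DInferencer('mvxnet_kitti-3class')

# A single sample: LiDAR points, the matching image and a pickled calibration dict.
inputs = dict(
    points='tests/data/kitti/training/velodyne/000000.bin',
    img='tests/data/kitti/training/image_2/000000.png',
    calib='tests/data/kitti/training/calib/000000.pkl')

# The returned dict carries 'predictions' (bboxes_3d, scores_3d, labels_3d per
# sample) and, because return_vis=True, a 'visualization' entry as well.
results = inferencer(inputs, return_vis=True, pred_score_thr=0.3)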
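
The same call accepts in-memory data, as covered by the ndarray branch of the new test; the calibration may also be passed as an already-loaded dict, which MultiModalityDet3DInferencerLoader handles alongside a .pkl path. A sketch reusing the inferencer from above:

import mmcv
import mmengine
import numpy as np

# KITTI points are flat float32 (x, y, z, intensity) records.
pts_bytes = mmengine.fileio.get('tests/data/kitti/training/velodyne/000000.bin')
points = np.frombuffer(pts_bytes, dtype=np.float32).reshape(-1, 4)
img = mmcv.imread('tests/data/kitti/training/image_2/000000.png')
calib = mmengine.load('tests/data/kitti/training/calib/000000.pkl')

results = inferencer(dict(points=points, img=img, calib=calib))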
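
The calibration file consumed above is simply a pickled dict. Per MultiModalityDet3DInferencerLoader, it must provide 'cam2img' and 'lidar2cam'; 'lidar2img' is optional and otherwise derived as cam2img @ lidar2cam. A sketch with identity placeholders (the matrices and the output path are illustrative, not real KITTI calibration):

import mmengine
import numpy as np

calib = dict(
    cam2img=np.eye(4, dtype=np.float32),    # camera intrinsics, padded to 4x4
    lidar2cam=np.eye(4, dtype=np.float32),  # LiDAR-to-camera extrinsics
    # 'lidar2img' may be given explicitly; otherwise cam2img @ lidar2cam is used
)
mmengine.dump(calib, 'my_calib.pkl')  # hypothetical output path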
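
For the functional API updated above, cam_type now defaults to 'CAM2' and, per the revised docstring, only KITTI-style single-view input is supported. A rough sketch under stated assumptions: the checkpoint path is a placeholder, the annotation file must contain the info entry for the chosen frame, and the two-value return follows how the existing demo scripts consume this function.

from mmdet3d.apis import inference_multi_modality_detector, init_model

config = 'configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py'
checkpoint = 'checkpoints/mvxnet_kitti-3d-3class.pth'  # hypothetical local file
model = init_model(config, checkpoint, device='cuda:0')

# The annotation pkl is assumed to hold the data_list entry for this frame.
result, data = inference_multi_modality_detector(
    model,
    pcds='tests/data/kitti/training/velodyne/000000.bin',
    imgs='tests/data/kitti/training/image_2/000000.png',
    ann_file='data/kitti/kitti_infos_val.pkl',
    cam_type='CAM2')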